def dev(config, bert_config, dev_path, id2rel, tokenizer, output_path=None):
    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)
    with torch.no_grad():
        Bert_model = BertModel(bert_config).to(device).eval()
        submodel = sub_model(config).to(device).eval()
        objmodel = obj_model(config).to(device).eval()

        state = torch.load(
            os.path.join(config.output_dir, config.load_model_name))
        Bert_model.load_state_dict(state['bert_state_dict'])
        submodel.load_state_dict(state['subject_state_dict'])
        objmodel.load_state_dict(state['object_state_dict'])

        precision, recall, f1 = utils.metric(Bert_model,
                                             submodel,
                                             objmodel,
                                             dev_data,
                                             id2rel,
                                             tokenizer,
                                             output_path=output_path)
        logger.info('precision: %.4f' % precision)
        logger.info('recall: %.4f' % recall)
        logger.info('F1: %.4f' % f1)
Example #2
def init_bert_model_with_teacher(
    student: BertModel,
    teacher: BertModel,
    layers_to_transfer: List[int] = None,
) -> BertModel:
    """Initialize student model with teacher layers.

    Args:
        student (BertModel): Student model.
        teacher (BertModel): Teacher model.
        layers_to_transfer (List[int], optional): Indices of the teacher layers to transfer.
            If None, the last teacher layers (as many as the student has) are transferred.
            Defaults to None.

    Returns:
        BertModel: Student model initialized with the selected teacher layers.
    """
    teacher_hidden_size = teacher.config.hidden_size
    student_hidden_size = student.config.hidden_size
    if teacher_hidden_size != student_hidden_size:
        raise Exception("Teacher and student hidden size should be the same")
    teacher_layers_num = teacher.config.num_hidden_layers
    student_layers_num = student.config.num_hidden_layers

    if layers_to_transfer is None:
        layers_to_transfer = list(
            range(teacher_layers_num - student_layers_num, teacher_layers_num))

    prefix_teacher = list(teacher.state_dict().keys())[0].split(".")[0]
    prefix_student = list(student.state_dict().keys())[0].split(".")[0]
    student_sd = _extract_layers(
        teacher_model=teacher,
        layers=layers_to_transfer,
    )
    student.load_state_dict(student_sd)
    return student
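# A minimal usage sketch (an assumption, not part of the original source; it presumes the
# transformers import below and the `_extract_layers` helper referenced above): build a
# 4-layer student that inherits the last four layers of a 12-layer teacher.
from transformers import BertConfig, BertModel

teacher = BertModel.from_pretrained("bert-base-uncased")
student = BertModel(BertConfig.from_pretrained("bert-base-uncased", num_hidden_layers=4))
student = init_bert_model_with_teacher(student, teacher)
assert student.config.num_hidden_layers == 4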
Example #3
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file), strict=False)
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_json(open(vocab_file, 'rt').read())
    return bertmodel, vocab_b_obj
Example #4
def get_kobert_model(model_file, vocab_file, ctx="cpu"):
    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    bertmodel.eval()
    vocab_b_obj = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                         padding_token='[PAD]')
    return bertmodel, vocab_b_obj
Example #5
def run(pretrained_model, out_dir, num_layers=3):
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
    model = BertModel.from_pretrained(pretrained_model, return_dict=True)

    small_config = copy.deepcopy(model.config)
    small_config.num_hidden_layers = num_layers
    small_model = BertModel(small_config)
    small_model.load_state_dict(model.state_dict(), strict=False)

    tokenizer.save_pretrained(out_dir)
    small_model.save_pretrained(out_dir)
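# Hypothetical usage sketch (model name and output directory are placeholders; assumes the
# imports used by run() above are available): keep only the first three encoder layers of a
# pretrained BERT, then reload the shrunken checkpoint.
run("bert-base-uncased", "./bert-base-3layer", num_layers=3)
small = BertModel.from_pretrained("./bert-base-3layer")
assert small.config.num_hidden_layers == 3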
class BioBert(nn.Module):
    def __init__(self, num_labels, config, state_dict):
        super().__init__()
        self.bert = BertModel(config)
        self.bert.load_state_dict(state_dict, strict=False)
        self.dropout = nn.Dropout(p=0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        #https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        # last_hidden_state: Sequence of hidden-states at the output of the last layer of the model.
        # pooler_output: Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function.
        # Ask for a plain tuple so the unpacking below also works with transformers >= 4,
        # where the model returns a ModelOutput object by default.
        last_hidden_state, pooler_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        output = self.dropout(pooler_output)
        out = self.classifier(output)
        return out
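# Hypothetical usage sketch (checkpoint name and label count are placeholders; assumes a
# transformers version whose BertModel can return the (last_hidden_state, pooler_output)
# tuple expected by the forward above): wrap a pretrained BioBERT checkpoint from the
# Hugging Face Hub and classify one sentence.
from transformers import BertConfig, BertModel, BertTokenizer

name = "dmis-lab/biobert-base-cased-v1.1"
config = BertConfig.from_pretrained(name)
state_dict = BertModel.from_pretrained(name).state_dict()
model = BioBert(num_labels=2, config=config, state_dict=state_dict).eval()

tokenizer = BertTokenizer.from_pretrained(name)
batch = tokenizer(["Aspirin reduces fever."], return_tensors="pt")
logits = model(batch["input_ids"], batch["attention_mask"])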
class BioBertNER(nn.Module):

  def __init__(self, num_labels, config, state_dict):
    super().__init__()
    self.bert = BertModel(config)
    self.bert.load_state_dict(state_dict, strict=False)
    self.dropout = nn.Dropout(p=0.3)
    self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
    self.softmax = nn.Softmax(dim=1)
    
  def forward(self, input_ids, attention_mask):
    encoded_layer, pooled_output = self.bert(input_ids=input_ids,
      attention_mask=attention_mask)
    enlayer = encoded_layer[-1]
    enlayer = self.dropout(enlayer)
    outlayer = self.classifier(enlayer)
    pooled_output = self.dropout(pooled_output)
    out = self.classifier(pooled_output)
    return out, outlayer
def get_bert(BERT_PT_PATH, bert_type, do_lower_case, no_pretraining):

    bert_config_file = os.path.join(BERT_PT_PATH,
                                    f'bert_config_{bert_type}.json')
    vocab_file = os.path.join(BERT_PT_PATH, f'vocab_{bert_type}.txt')
    init_checkpoint = os.path.join(BERT_PT_PATH,
                                   f'pytorch_model_{bert_type}.bin')

    bert_config = BertConfig.from_json_file(bert_config_file)
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
    bert_config.print_status()

    model_bert = BertModel(bert_config)
    if no_pretraining:
        pass
    else:
        model_bert.load_state_dict(
            torch.load(init_checkpoint, map_location='cpu'))
        print("Load pre-trained parameters.")
    model_bert.to(device)

    return model_bert, tokenizer, bert_config
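# Hypothetical usage sketch (the directory and bert_type are placeholders and must match the
# file naming scheme expected above, e.g. bert_config_<type>.json / vocab_<type>.txt):
model_bert, tokenizer, bert_config = get_bert('./data_and_model',
                                              'uncased_L-12_H-768_A-12',
                                              do_lower_case=True,
                                              no_pretraining=False)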
Example #9
    def _load_bert(self, bert_config_path: str, bert_model_path: str):
        bert_config = BertConfig.from_json_file(bert_config_path)
        model = BertModel(bert_config)
        if self.cuda:
            model_states = torch.load(bert_model_path)
        else:
            model_states = torch.load(bert_model_path, map_location='cpu')
        # Remap keys saved by older training scripts: strip the "bert." prefix, drop the
        # pre-training heads ("cls.*"), and rename the legacy LayerNorm parameters
        # (gamma/beta) to their current names (weight/bias).
        for k in list(model_states.keys()):
            new_k = k
            if new_k.startswith("bert."):
                new_k = new_k[5:]
            elif new_k.startswith("cls"):
                model_states.pop(k)
                continue
            if new_k.endswith("gamma"):
                new_k = new_k[:-5] + "weight"
            elif new_k.endswith("beta"):
                new_k = new_k[:-4] + "bias"
            if new_k != k:
                model_states[new_k] = model_states.pop(k)

        model.load_state_dict(model_states)
        if self.cuda:
            model.cuda()
        model.eval()
        return model
Example #10
def get_pretrained_model(path, logger, args=None):
    logger.info('load pretrained model in {}'.format(path))
    bert_tokenizer = BertTokenizer.from_pretrained(path)
    
    if args is None or args.hidden_layers == 12:
        bert_config = BertConfig.from_pretrained(path)
        bert_model = BertModel.from_pretrained(path)

    else:
        logger.info('load {} layers bert'.format(args.hidden_layers))
        bert_config = BertConfig.from_pretrained(path, num_hidden_layers=args.hidden_layers)
        bert_model = BertModel(bert_config)
        model_param_list = [p[0] for p in bert_model.named_parameters()]
        load_dict = torch.load(os.path.join(path, 'pytorch_model.bin'))
        new_load_dict = {}
        for k, v in load_dict.items():
            k = k.replace('bert.', '')
            if k in model_param_list:
                new_load_dict[k] = v
        new_load_dict['embeddings.position_ids'] = torch.arange(512).unsqueeze(0)
        bert_model.load_state_dict(new_load_dict)

    logger.info('load complete')
    return bert_config, bert_tokenizer, bert_model
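# Hypothetical usage sketch (the checkpoint directory is a placeholder for a locally saved
# BERT model that contains pytorch_model.bin): keep only the first six encoder layers.
import logging
from argparse import Namespace

logging.basicConfig(level=logging.INFO)
bert_config, bert_tokenizer, bert_model = get_pretrained_model(
    "./bert-base-chinese",
    logging.getLogger(__name__),
    args=Namespace(hidden_layers=6))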
Example #11
class RenamingModelHybrid(nn.Module):
    def __init__(self, vocab, top_k, config, device):
        super(RenamingModelHybrid, self).__init__()

        self.vocab = vocab
        self.top_k = top_k
        self.source_vocab_size = len(self.vocab.source_tokens) + 1

        self.graph_encoder = GraphASTEncoder.build(
            config['encoder']['graph_encoder'])
        self.graph_emb_size = config['encoder']['graph_encoder']['gnn'][
            'hidden_size']
        self.emb_size = 256

        state_dict = torch.load(
            'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
            map_location=device)

        keys_to_delete = [
            "cls.predictions.bias", "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.LayerNorm.weight",
            "cls.predictions.transform.LayerNorm.bias",
            "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
            "cls.seq_relationship.weight", "cls.seq_relationship.bias"
        ]

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete: continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v

        bert_config = BertConfig(vocab_size=self.source_vocab_size,
                                 max_position_embeddings=512,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4)
        self.bert_encoder = BertModel(bert_config)
        self.bert_encoder.load_state_dict(new_state_dict)

        self.target_vocab_size = len(self.vocab.all_subtokens) + 1

        bert_config = BertConfig(vocab_size=self.target_vocab_size,
                                 max_position_embeddings=1000,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4,
                                 is_decoder=True)
        self.bert_decoder = BertModel(bert_config)

        state_dict = torch.load(
            'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
            map_location=device)

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete: continue
            if 'crossattention' in k: continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v

        for key in new_state_dict:
            self.bert_decoder.state_dict()[key].copy_(new_state_dict[key])

        self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size,
                                       self.emb_size)
        self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)

        self.fc_final.weight.data = state_dict['model'][
            'cls.predictions.decoder.weight']

    def forward(self, src_tokens, src_mask, variable_ids, target_tokens,
                graph_input):
        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0
        assert torch.max(target_tokens) < self.target_vocab_size
        assert torch.min(target_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']

        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        decoder_attention_mask = torch.ones_like(target_tokens).float().to(
            target_tokens.device)
        decoder_attention_mask[target_tokens == PAD_ID] = 0.0

        decoder_output = self.bert_decoder(
            input_ids=target_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=full_enc_output,
            encoder_attention_mask=encoder_attention_mask)[0]

        predictions = self.fc_final(decoder_output)

        return predictions

    def predict(self,
                src_tokens,
                src_mask,
                variable_ids,
                graph_input,
                approx=False):
        end_token = self.vocab.all_subtokens.word2id['</s>']
        start_token = self.vocab.all_subtokens.word2id['<s>']
        batch_size = src_tokens.shape[0]

        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']

        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        source_vocab_to_target = {
            self.vocab.source_tokens.word2id[t]:
            self.vocab.all_subtokens.word2id[t]
            for t in self.vocab.source_tokens.word2id.keys()
        }
        src_target_maps = []
        confidences = []

        for i in range(batch_size):

            if src_tokens[i][0] != start_token:
                input_sequence = torch.zeros(src_tokens.shape[1] + 1).to(
                    src_tokens.device)
                input_mask = torch.zeros(src_mask.shape[1] + 1).to(
                    src_mask.device)
                input_sequence[1:] = src_tokens[i]
                input_mask[1:] = src_mask[i]
            else:
                input_sequence = src_tokens[i]
                input_mask = src_mask[i]

            num_vars = int(input_mask.sum())
            seq_len = torch.sum((input_sequence != PAD_ID).long())
            generated_seqs = torch.zeros(1, min(
                seq_len + 10 * num_vars, 1000)).long().to(src_tokens.device)

            source_marker = 0
            gen_markers = torch.LongTensor([0]).to(generated_seqs.device)
            prior_probs = torch.FloatTensor([0]).to(generated_seqs.device)

            candidate_maps = [{}]

            for _ in range(num_vars):
                # Filling up the known (non-identifier) tokens
                while source_marker < seq_len and input_mask[source_marker] != 1:
                    token = input_sequence[source_marker]
                    values = source_vocab_to_target[token.item()] * torch.ones_like(
                        gen_markers).to(generated_seqs.device)

                    generated_seqs = torch.scatter(generated_seqs, 1,
                                                   gen_markers.unsqueeze(1),
                                                   values.unsqueeze(1))

                    source_marker += 1
                    gen_markers += 1

                if source_marker >= seq_len: break

                curr_var = input_sequence[source_marker].item()

                if curr_var in candidate_maps[0]:
                    if approx is True:
                        source_marker += 1
                        continue
                    # If we've seen this variable before, just use the previous predictions and update the scores
                    # Note - it's enough to check candidate_maps[0] because if it is in the first map, it is in all of them

                    orig_markers = gen_markers.clone()

                    for j in range(len(candidate_maps)):
                        pred = candidate_maps[j][curr_var]
                        generated_seqs[j][gen_markers[j]:gen_markers[j] +
                                          len(pred)] = torch.LongTensor(
                                              pred).to(generated_seqs.device)
                        gen_markers[j] += len(pred)

                    decoder_attention_mask = torch.ones_like(
                        generated_seqs).float().to(generated_seqs.device)
                    decoder_attention_mask[generated_seqs == PAD_ID] = 0.0

                    decoder_output = self.bert_decoder(
                        input_ids=generated_seqs,
                        attention_mask=decoder_attention_mask,
                        encoder_hidden_states=full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask=encoder_attention_mask[i].
                        unsqueeze(0))[0]

                    probabilities = F.log_softmax(
                        self.fc_final(decoder_output), dim=-1)

                    # Add up the scores of the token at the __next__ time step

                    scores = torch.zeros(generated_seqs.shape[0]).to(
                        generated_seqs.device)
                    active = torch.ones(generated_seqs.shape[0]).long().to(
                        generated_seqs.device)
                    temp_markers = orig_markers

                    while torch.sum(active) != 0:
                        position_probs = torch.gather(
                            probabilities, 1,
                            (temp_markers - 1).reshape(-1, 1, 1).repeat(
                                1, 1, probabilities.shape[2])).squeeze(1)
                        curr_tokens = torch.gather(generated_seqs, 1,
                                                   temp_markers.unsqueeze(1))
                        tok_probs = torch.gather(position_probs, 1,
                                                 curr_tokens).squeeze(1)

                        tok_probs *= active
                        scores += tok_probs

                        active *= (temp_markers != (gen_markers - 1)).long()
                        temp_markers += active

                    # Update the prior probabilities
                    prior_probs = prior_probs + scores

                else:
                    # You encounter a new variable which hasn't been seen before
                    # Generate <beam_width> possibilities for its name
                    generated_seqs, gen_markers, prior_probs, candidate_maps = self.beam_search(
                        generated_seqs,
                        gen_markers,
                        prior_probs,
                        candidate_maps,
                        curr_var,
                        full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask[i].unsqueeze(0),
                        beam_width=5,
                        top_k=self.top_k)

                source_marker += 1

            final_ind = torch.argmax(prior_probs)
            confidence = torch.max(prior_probs).item()
            src_target_map = candidate_maps[final_ind]

            src_target_maps.append(src_target_map)
            confidences.append(confidence)

        return src_target_maps, confidences

    def beam_search(self,
                    generated_seqs,
                    gen_markers,
                    prior_probs,
                    candidate_maps,
                    curr_var,
                    full_enc_output,
                    encoder_attention_mask,
                    beam_width=5,
                    top_k=10):

        if generated_seqs.shape[0] * beam_width < top_k:
            beam_width = top_k

        active = torch.ones_like(gen_markers).to(gen_markers.device)
        beam_alpha = 0.7
        end_token = self.vocab.all_subtokens.word2id['</s>']

        candidate_maps = candidate_maps
        orig_markers = gen_markers.clone()

        for _ in range(10):  # Predict at most 10 subtokens
            decoder_attention_mask = torch.ones_like(
                generated_seqs).float().to(generated_seqs.device)
            decoder_attention_mask[generated_seqs == PAD_ID] = 0.0

            decoder_output = self.bert_decoder(
                input_ids=generated_seqs,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=full_enc_output,
                encoder_attention_mask=encoder_attention_mask)[0]
            probabilities = F.log_softmax(self.fc_final(decoder_output),
                                          dim=-1)
            # Gather the predictions at the current markers
            # (gen_marker - 1) because prediction happens one step ahead
            probabilities = torch.gather(
                probabilities, 1, (gen_markers - 1).reshape(-1, 1, 1).repeat(
                    1, 1, probabilities.shape[2])).squeeze(1)

            probs, preds = probabilities.sort(dim=-1, descending=True)

            probs *= active.unsqueeze(1)  # Zero out the log probs of non-active beams
            # Non-active beams keep predicting the end token (i.e., they remain unchanged)
            preds[active == 0] = end_token

            # Repeat active ones only once. Repeat the rest beam_width no. of times.
            filter_mask = torch.ones(
                (preds.shape[0], beam_width)).long().to(preds.device)
            filter_mask *= active.unsqueeze(1)
            filter_mask[:, 0][active == 0] = 1
            filter_mask = filter_mask.reshape(-1)

            preds = preds[:, :beam_width].reshape(-1)[filter_mask == 1]
            probs = probs[:, :beam_width].reshape(-1)[filter_mask == 1]

            generated_seqs = torch.repeat_interleave(generated_seqs,
                                                     beam_width,
                                                     dim=0)[filter_mask == 1]
            orig_markers = torch.repeat_interleave(orig_markers,
                                                   beam_width,
                                                   dim=0)[filter_mask == 1]
            gen_markers = torch.repeat_interleave(gen_markers,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            active = torch.repeat_interleave(active, beam_width,
                                             dim=0)[filter_mask == 1]
            prior_probs = torch.repeat_interleave(prior_probs,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]

            candidate_maps = [
                item.copy() for item in candidate_maps
                for _ in range(beam_width)
            ]
            candidate_maps = [
                candidate_maps[i] for i in range(len(candidate_maps))
                if filter_mask[i] == 1
            ]

            generated_seqs.scatter_(1, gen_markers.unsqueeze(1),
                                    preds.unsqueeze(1))

            # lengths       = (gen_markers - gen_marker + 1).float()
            # penalties     = torch.pow(5 + lengths, beam_alpha) / math.pow(6, beam_alpha)
            penalties = torch.ones_like(probs).to(probs.device)

            updated_probs = probs + prior_probs

            sort_inds = (updated_probs / penalties).argsort(descending=True)
            updated_probs = updated_probs[sort_inds]

            prior_probs = updated_probs[:top_k]

            new_preds = preds[sort_inds[:top_k]]
            generated_seqs = generated_seqs[sort_inds[:top_k]]
            gen_markers = gen_markers[sort_inds[:top_k]]
            active = active[sort_inds[:top_k]]
            orig_markers = orig_markers[sort_inds[:top_k]]

            candidate_maps = [
                candidate_maps[ind.item()] for ind in sort_inds[:top_k]
            ]

            active = active * (new_preds != end_token).long()
            gen_markers += active

            if torch.sum(active) == 0: break

        # gen_markers are pointing at the end_token. Move them one ahead
        gen_markers += 1

        assert generated_seqs.shape[0] == top_k

        for i in range(top_k):
            candidate_maps[i][curr_var] = generated_seqs[i][
                orig_markers[i]:gen_markers[i]].cpu().tolist()

        return generated_seqs, gen_markers, prior_probs, candidate_maps
class SequentialEncoder(Encoder):
    def __init__(self, config):
        super().__init__()

        self.vocab = vocab  = Vocab.load(config['vocab_file'])
        self.src_word_embed = nn.Embedding(len(vocab.source_tokens), config['source_embedding_size'])
        self.config = config

        self.decoder_cell_init = nn.Linear(config['source_encoding_size'], config['decoder_hidden_size'])

        if self.config['transformer'] == 'none':
            dropout = config['dropout']
            self.lstm_encoder = nn.LSTM(input_size=self.src_word_embed.embedding_dim,
                                        hidden_size=config['source_encoding_size'] // 2, num_layers=config['num_layers'],
                                        batch_first=True, bidirectional=True, dropout=dropout)

            self.dropout = nn.Dropout(dropout)

        elif self.config['transformer'] == 'bert':
            self.vocab_size = len(self.vocab.source_tokens) + 1

            state_dict = torch.load('saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth')

            keys_to_delete = ["cls.predictions.bias", "cls.predictions.transform.dense.weight", "cls.predictions.transform.dense.bias", "cls.predictions.transform.LayerNorm.weight",
                            "cls.predictions.transform.LayerNorm.bias", "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
                            "cls.seq_relationship.weight", "cls.seq_relationship.bias"]

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete: continue
                name = k[5:] # remove `bert.`
                new_state_dict[name] = v

            bert_config = BertConfig(vocab_size=self.vocab_size, max_position_embeddings=512, num_hidden_layers=6, hidden_size=256, num_attention_heads=4)
            self.bert_model = BertModel(bert_config)
            self.bert_model.load_state_dict(new_state_dict)

        elif self.config['transformer'] == 'xlnet':
            self.vocab_size = len(self.vocab.source_tokens) + 1

            state_dict = torch.load('saved_checkpoints/xlnet_2704/xlnet1_pretrained_epoch_13_iter_500000.pth')

            keys_to_delete = ["lm_loss.weight", "lm_loss.bias"]

            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict['model'].items():
                if k in keys_to_delete: continue
                if k.startswith('transformer.'):
                    name = k[len('transformer.'):]
                else:
                    name = k
                new_state_dict[name] = v

            xlnet_config = XLNetConfig(vocab_size=self.vocab_size, d_model=256, n_layer=12)
            self.xlnet_model = XLNetModel(xlnet_config)
            self.xlnet_model.load_state_dict(new_state_dict)
        else:
            print("Error! Unknown transformer type '{}'".format(self.config['transformer']))

    @property
    def device(self):
        return self.src_word_embed.weight.device

    @classmethod
    def default_params(cls):
        return {
            'source_encoding_size': 256,
            'decoder_hidden_size': 128,
            'source_embedding_size': 128,
            'vocab_file': None,
            'num_layers': 1
        }

    @classmethod
    def build(cls, config):
        params = util.update(SequentialEncoder.default_params(), config)

        return cls(params)

    def forward(self, tensor_dict: Dict[str, torch.Tensor]):
        if self.config['transformer'] == 'bert':
            code_token_encoding, code_token_mask = self.encode_bert(tensor_dict['src_code_tokens'])
        elif self.config['transformer'] == 'xlnet':
            code_token_encoding, code_token_mask = self.encode_xlnet(tensor_dict['src_code_tokens'])
        elif self.config['transformer'] == 'none':
            code_token_encoding, code_token_mask, (last_states, last_cells) = self.encode_sequence(tensor_dict['src_code_tokens'])
        else:
            print("Error! Unknown transformer type '{}'".format(self.config['transformer']))
        # (batch_size, max_variable_mention_num)
        # variable_mention_positions = tensor_dict['variable_position']
        variable_mention_mask = tensor_dict['variable_mention_mask']
        variable_mention_to_variable_id = tensor_dict['variable_mention_to_variable_id']

        # (batch_size, max_variable_num)
        variable_encoding_mask = tensor_dict['variable_encoding_mask']
        variable_mention_num = tensor_dict['variable_mention_num']

        # # (batch_size, max_variable_mention_num, encoding_size)
        # variable_mention_encoding = torch.gather(code_token_encoding, 1, variable_mention_positions.unsqueeze(-1).expand(-1, -1, code_token_encoding.size(-1))) * variable_mention_positions_mask
        max_time_step = variable_mention_to_variable_id.size(1)
        variable_num = variable_mention_num.size(1)
        encoding_size = code_token_encoding.size(-1)

        variable_mention_encoding = code_token_encoding * variable_mention_mask.unsqueeze(-1)
        variable_encoding = torch.zeros(tensor_dict['batch_size'], variable_num, encoding_size, device=self.device)
        variable_encoding.scatter_add_(1,
                                       variable_mention_to_variable_id.unsqueeze(-1).expand(-1, -1, encoding_size),
                                       variable_mention_encoding) * variable_encoding_mask.unsqueeze(-1)
        variable_encoding = variable_encoding / (variable_mention_num + (1. - variable_encoding_mask) * nn_util.SMALL_NUMBER).unsqueeze(-1)

        if self.config['transformer'] == 'bert' or self.config['transformer'] == 'xlnet':
            context_encoding = dict(
                variable_encoding=variable_encoding,
                code_token_encoding=code_token_encoding,
                code_token_mask=code_token_mask
            )
        else:
            context_encoding = dict(
                variable_encoding=variable_encoding,
                code_token_encoding=code_token_encoding,
                code_token_mask=code_token_mask,
                last_states=last_states,
                last_cells=last_cells
            )

        context_encoding.update(tensor_dict)

        return context_encoding

    def encode_xlnet(self, input_ids):

        attention_mask = torch.ones_like(input_ids).float()
        attention_mask[input_ids == PAD_ID] = 0.0

        assert torch.max(input_ids) < self.vocab_size
        assert torch.min(input_ids) >= 0

        if torch.cuda.is_available():
            input_ids       = input_ids.cuda()
            attention_mask  = attention_mask.cuda()

        outputs = self.xlnet_model(input_ids=input_ids, attention_mask=attention_mask)

        return outputs[0], attention_mask

    def encode_bert(self, input_ids):

        attention_mask = torch.ones_like(input_ids).float()
        attention_mask[input_ids == PAD_ID] = 0.0

        assert torch.max(input_ids) < self.vocab_size
        assert torch.min(input_ids) >= 0

        if torch.cuda.is_available():
            input_ids       = input_ids.cuda()
            attention_mask  = attention_mask.cuda()

        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)

        return outputs[0], attention_mask

    def encode_sequence(self, code_sequence):
        # (batch_size, max_code_length)
        # code_sequence = tensor_dict['src_code_tokens']

        # (batch_size, max_code_length, embed_size)
        code_token_embedding = self.src_word_embed(code_sequence)

        # (batch_size, max_code_length)
        code_token_mask = torch.ne(code_sequence, PAD_ID).float()
        # (batch_size)
        code_sequence_length = code_token_mask.sum(dim=-1).long()

        sorted_seqs, sorted_seq_lens, restoration_indices, sorting_indices = nn_util.sort_batch_by_length(code_token_embedding,
                                                                                                          code_sequence_length)

        packed_question_embedding = pack_padded_sequence(sorted_seqs, sorted_seq_lens.data.tolist(), batch_first=True)

        sorted_encodings, (last_states, last_cells) = self.lstm_encoder(packed_question_embedding)
        sorted_encodings, _ = pad_packed_sequence(sorted_encodings, batch_first=True)

        # apply dropout to the last layer
        # (batch_size, seq_len, hidden_size * 2)
        sorted_encodings = self.dropout(sorted_encodings)

        # (batch_size, question_len, hidden_size * 2)
        restored_encodings = sorted_encodings.index_select(dim=0, index=restoration_indices)

        # (num_layers, direction_num, batch_size, hidden_size)
        last_states = last_states.view(self.lstm_encoder.num_layers, 2, -1, self.lstm_encoder.hidden_size)
        last_states = last_states.index_select(dim=2, index=restoration_indices)
        last_cells = last_cells.view(self.lstm_encoder.num_layers, 2, -1, self.lstm_encoder.hidden_size)
        last_cells = last_cells.index_select(dim=2, index=restoration_indices)

        return restored_encodings, code_token_mask, (last_states, last_cells)

    @classmethod
    def to_tensor_dict(cls, examples: List[Example], next_examples=None, flips=None) -> Dict[str, torch.Tensor]:
        if next_examples is not None:
            max_time_step = max(e.source_seq_length + n.source_seq_length for e,n in zip(examples, next_examples))
        else:
            max_time_step = max(e.source_seq_length for e in examples)

        input = np.zeros((len(examples), max_time_step), dtype=np.int64)

        if next_examples is not None:
            seq_mask = torch.zeros((len(examples), max_time_step), dtype=torch.long)
        else:
            seq_mask = None

        variable_mention_to_variable_id = torch.zeros(len(examples), max_time_step, dtype=torch.long)
        variable_mention_mask = torch.zeros(len(examples), max_time_step)
        variable_mention_num = torch.zeros(len(examples), max(len(e.ast.variables) for e in examples))
        variable_encoding_mask = torch.zeros(variable_mention_num.size())

        for e_id, example in enumerate(examples):
            sub_tokens = example.sub_tokens
            input[e_id, :len(sub_tokens)] = example.sub_token_ids

            if next_examples is not None:
                next_example = next_examples[e_id]
                next_tokens = next_example.sub_tokens
                input[e_id, len(sub_tokens):len(sub_tokens)+len(next_tokens)] = next_example.sub_token_ids
                seq_mask[e_id, len(sub_tokens):] = 1
                # seq_mask[e_id, len(sub_tokens):len(sub_tokens)+len(next_tokens)] = 1

            variable_position_map = dict()
            var_name_to_id = {name: i for i, name in enumerate(example.ast.variables)}
            for i, sub_token in enumerate(sub_tokens):
                if sub_token.startswith('@@') and sub_token.endswith('@@'):
                    old_var_name = sub_token[2: -2]
                    if old_var_name in var_name_to_id:  # sometimes there are strings like `@@@@`
                        var_id = var_name_to_id[old_var_name]

                        variable_mention_to_variable_id[e_id, i] = var_id
                        variable_mention_mask[e_id, i] = 1.
                        variable_position_map.setdefault(old_var_name, []).append(i)

            for var_id, var_name in enumerate(example.ast.variables):
                try:
                    var_pos = variable_position_map[var_name]
                    variable_mention_num[e_id, var_id] = len(var_pos)
                except KeyError:
                    variable_mention_num[e_id, var_id] = 1
                    print(example.binary_file, f'variable [{var_name}] not found', file=sys.stderr)

            variable_encoding_mask[e_id, :len(example.ast.variables)] = 1.

        batch_dict =  dict(src_code_tokens=torch.from_numpy(input),
                            variable_mention_to_variable_id=variable_mention_to_variable_id,
                            variable_mention_mask=variable_mention_mask,
                            variable_mention_num=variable_mention_num,
                            variable_encoding_mask=variable_encoding_mask,
                            batch_size=len(examples))

        if next_examples is not None:
            batch_dict['next_seq_mask'] = seq_mask
            batch_dict['next_sentence_label'] = torch.LongTensor(flips)

        return batch_dict

    def get_decoder_init_state(self, context_encoder, config=None):
        if 'last_cells' not in context_encoder:
            if self.config['init_decoder']:
                dec_init_cell = self.decoder_cell_init(torch.mean(context_encoder['code_token_encoding'], dim=1))
                dec_init_state = torch.tanh(dec_init_cell)
            else:
                dec_init_cell = dec_init_state = None

        elif 'last_cells' in context_encoder:
            fwd_last_layer_cell = context_encoder['last_cells'][-1, 0]
            bak_last_layer_cell = context_encoder['last_cells'][-1, 1]

            dec_init_cell = self.decoder_cell_init(torch.cat([fwd_last_layer_cell, bak_last_layer_cell], dim=-1))
            dec_init_state = torch.tanh(dec_init_cell)

        return dec_init_state, dec_init_cell

    def get_attention_memory(self, context_encoding, att_target='terminal_nodes'):
        assert att_target == 'terminal_nodes'

        memory = context_encoding['code_token_encoding']
        mask = context_encoding['code_token_mask']

        return memory, mask
Example #13
def load_bert(bert_path, device):
    bert_config_path = os.path.join(bert_path, 'config.json')
    bert = BertModel(BertConfig(**load_json(bert_config_path))).to(device)
    bert_model_path = os.path.join(bert_path, 'model.bin')
    bert.load_state_dict(clean_state_dict(torch.load(bert_model_path)))
    return bert
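# `load_json` and `clean_state_dict` are not shown in this example; the sketch below is an
# assumption about what they might look like, not the original helpers: read a JSON config
# and strip common checkpoint key prefixes such as "module." (from DataParallel) or "bert.".
import json

def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def clean_state_dict(state_dict):
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in ("module.", "bert."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned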
Example #14
def main():
    parser = argparse.ArgumentParser()

    # 1. Paths to the training and test data
    parser.add_argument("--data_dir",
                        default='./data/cluener',
                        type=str,
                        help="Path to data.")
    parser.add_argument("--type_description",
                        default='./data/cluener/type_des.json',
                        type=str,
                        help="Path to data.")

    # 2. Paths to the pretrained model
    parser.add_argument("--vocab_file",
                        default="./data/pretrain/vocab.txt",
                        type=str,
                        help="Init vocab to resume training from.")
    parser.add_argument("--config_path",
                        default="./data/pretrain/config.json",
                        type=str,
                        help="Init config to resume training from.")
    parser.add_argument("--init_checkpoint",
                        default="./data/pretrain/pytorch_model.bin",
                        type=str,
                        help="Init checkpoint to resume training from.")

    # 3. Checkpoint saving and loading
    parser.add_argument("--save_path",
                        default="./check_points/",
                        type=str,
                        help="Path to save checkpoints.")
    parser.add_argument("--load_path",
                        default=None,
                        type=str,
                        help="Path to load checkpoints.")

    # Training and evaluation parameters
    parser.add_argument("--do_train",
                        default=True,
                        type=bool,
                        help="Whether to perform training.")
    parser.add_argument("--do_eval",
                        default=True,
                        type=bool,
                        help="Whether to perform evaluation on test data set.")
    parser.add_argument("--do_predict",
                        default=False,
                        type=bool,
                        help="Whether to perform evaluation on test data set.")
    parser.add_argument("--do_adv", default=True, type=bool)

    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="Number of epoches for fine-tuning.")
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total examples' number in batch for training.")
    parser.add_argument("--eval_batch_size",
                        default=1,
                        type=int,
                        help="Total examples' number in batch for eval.")
    parser.add_argument("--max_seq_len",
                        default=300,
                        type=int,
                        help="Number of words of the longest seqence.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="Learning rate used to train with warmup.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")

    parser.add_argument("--use_cuda",
                        type=bool,
                        default=True,
                        help="whether to use cuda")
    parser.add_argument("--log_steps",
                        type=int,
                        default=20,
                        help="The steps interval to print loss.")
    parser.add_argument("--eval_step",
                        type=int,
                        default=1000,
                        help="The steps interval to print loss.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    if args.use_cuda:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        n_gpu = 0
    logger.info("device: {}, n_gpu: {}".format(device, n_gpu))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    model_path_postfix = ''
    if args.do_adv:
        model_path_postfix += '_adv'

    args.save_path = os.path.join(args.save_path, 'ner' + model_path_postfix)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    bert_tokenizer = util.CNerTokenizer.from_pretrained(args.vocab_file)
    bert_config = BertConfig.from_pretrained(args.config_path)

    type2description = json.load(open(args.type_description))

    # Load the datasets
    train_dataset = None
    eval_dataset = None
    if args.do_train:
        logger.info("loading train dataset")
        train_dataset = data_helper.NER_dataset(
            os.path.join(args.data_dir, 'train.json'), bert_tokenizer,
            args.max_seq_len, type2description)

    if args.do_eval:
        logger.info("loading eval dataset")
        eval_dataset = data_helper.NER_dataset(os.path.join(
            args.data_dir, 'dev.json'),
                                               bert_tokenizer,
                                               args.max_seq_len,
                                               type2description,
                                               shuffle=False)

    if args.do_predict:
        logger.info("loading test dataset")
        test_dataset = data_helper.NER_dataset(os.path.join(
            args.data_dir, 'test.json'),
                                               bert_tokenizer,
                                               args.max_seq_len,
                                               type2description,
                                               shuffle=False)

    if args.do_train:
        logging.info("Start training !")
        train_helper.train(bert_tokenizer, bert_config, args, train_dataset,
                           eval_dataset)

    if not args.do_train and args.do_eval:
        logging.info("Start evaluating !")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)

        state = torch.load(args.load_path)
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logging.info("Checkpoint: %s have been loaded!" % (args.load_path))

        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()
        model_list = [bert_model, span_model]
        train_helper.evaluate(args, eval_dataset, model_list)

    if args.do_predict:
        logging.info("Start predicting !")
        bert_model = BertModel(config=bert_config)
        span_model = span_type.EntitySpan(config=bert_config)

        state = torch.load(args.load_path)
        bert_model.load_state_dict(state['bert_state_dict'])
        span_model.load_state_dict(state['span_state_dict'])
        logging.info("Checkpoint: %s have been loaded!" % (args.load_path))

        if args.use_cuda:
            bert_model.cuda()
            span_model.cuda()

        model_list = [bert_model, span_model]
        predict_res = train_helper.predict(args, test_dataset, model_list)
Example #15
bert_config = {  # BertConfig parameters for KoBERT
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'type_vocab_size': 2,
    'vocab_size': 8002
}

if __name__ == "__main__":
    ctx = "cpu"
    # kobert
    kobert_model_file = "./kobert_resources/pytorch_kobert_2439f391a6.params"
    kobert_vocab_file = "./kobert_resources/kobert_news_wiki_ko_cased-ae5711deb3.spiece"

    bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
    bertmodel.load_state_dict(torch.load(kobert_model_file))
    device = torch.device(ctx)
    bertmodel.to(device)
    # bertmodel.eval()

    # for name, param in bertmodel.named_parameters():
    #     print(name, param.shape)

    for name, param in bertmodel.named_parameters():
        if param.requires_grad:
            print(name, param.shape)
Example #16
class NERPredict(IPredict):
    '''
    Constructor: initializes the predictor.
    use_gpu: whether to run on the GPU
    bert_config_file_name: path to the BERT model config file
    vocab_file_name: path to the vocabulary file
    tags_file_name: path to the tag list file
    bert_model_path: path to the saved BERT model
    lstm_crf_model_path: path to the saved LSTM-CRF model
    hidden_dim: hidden dimension of the LSTM-CRF
    '''
    def __init__(self, use_gpu, bert_config_file_name, vocab_file_name,
                 tags_file_name, bert_model_path, lstm_crf_model_path,
                 hidden_dim):
        self.use_gpu = use_gpu
        self.data_manager_init(vocab_file_name, tags_file_name)
        self.tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
        self.model_init(hidden_dim, bert_config_file_name, bert_model_path,
                        lstm_crf_model_path)

    def data_manager_init(self, vocab_file_name, tags_file_name):
        tags_list = BERTDataManager.ReadTagsList(tags_file_name)
        tags_list = [tags_list]
        self.dm = BERTDataManager(tags_list=tags_list,
                                  vocab_file_name=vocab_file_name)

    def model_init(self, hidden_dim, bert_config_file_name, bert_model_path,
                   lstm_crf_model_path):
        config = BertConfig.from_json_file(bert_config_file_name)

        self.model = BertModel(config)

        bert_dict = torch.load(bert_model_path).module.state_dict()

        self.model.load_state_dict(bert_dict)
        self.birnncrf = torch.load(lstm_crf_model_path)

        self.model.eval()
        self.birnncrf.eval()

    def data_process(self, sentences):
        result = []
        pad_tag = '[PAD]'
        if type(sentences) == str:
            sentences = [sentences]
        max_len = 0
        for sentence in sentences:
            encode = self.tokenizer.encode(sentence, add_special_tokens=True)
            result.append(encode)
            if max_len < len(encode):
                max_len = len(encode)

        for i, sentence in enumerate(result):
            remain = max_len - len(sentence)
            for _ in range(remain):
                result[i].append(self.dm.wordToIdx(pad_tag))
        return torch.tensor(result)

    def pred(self, sentences):
        sentences = self.data_process(sentences)

        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.birnncrf.cuda()
            sentences = sentences.cuda()

        outputs = self.model(input_ids=sentences,
                             attention_mask=sentences.gt(0))
        hidden_states = outputs[0]
        scores, tags = self.birnncrf(hidden_states, sentences.gt(0))
        final_tags = []
        decode_sentences = []

        for item in tags:
            final_tags.append([self.dm.idx_to_tag[tag] for tag in item])

        for item in sentences.tolist():
            decode_sentences.append(self.tokenizer.decode(item))

        return (scores, tags, final_tags, decode_sentences)

    def __call__(self, sentences):
        return self.pred(sentences)
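# Hypothetical usage sketch (all paths are placeholders for artifacts produced by the
# corresponding training script):
predictor = NERPredict(use_gpu=False,
                       bert_config_file_name="./model/bert_config.json",
                       vocab_file_name="./model/vocab.txt",
                       tags_file_name="./model/tags.txt",
                       bert_model_path="./model/bert_model.pth",
                       lstm_crf_model_path="./model/lstm_crf.pth",
                       hidden_dim=128)
scores, tags, final_tags, decoded = predictor("小明住在北京市海淀区")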
Example #17
    def __init__(self):
        self.src_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
        self.tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.tgt_tokenizer.bos_token = '<s>'
        self.tgt_tokenizer.eos_token = '</s>'

        # hidden_size and intermediate_size both refer to the full model width (all attention
        # heads combined); hidden_size should be divisible by num_attention_heads.
        encoder_config = BertConfig(vocab_size=self.src_tokenizer.vocab_size,
                                    hidden_size=config.hidden_size,
                                    num_hidden_layers=config.num_hidden_layers,
                                    num_attention_heads=config.num_attention_heads,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act=config.hidden_act,
                                    hidden_dropout_prob=config.dropout_prob,
                                    attention_probs_dropout_prob=config.dropout_prob,
                                    max_position_embeddings=512,
                                    type_vocab_size=2,
                                    initializer_range=0.02,
                                    layer_norm_eps=1e-12)

        decoder_config = BertConfig(vocab_size=self.tgt_tokenizer.vocab_size,
                                    hidden_size=config.hidden_size,
                                    num_hidden_layers=config.num_hidden_layers,
                                    num_attention_heads=config.num_attention_heads,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act=config.hidden_act,
                                    hidden_dropout_prob=config.dropout_prob,
                                    attention_probs_dropout_prob=config.dropout_prob,
                                    max_position_embeddings=512,
                                    type_vocab_size=2,
                                    initializer_range=0.02,
                                    layer_norm_eps=1e-12,
                                    is_decoder=True)

        #Create encoder and decoder embedding layers.
        encoder_embeddings = torch.nn.Embedding(self.src_tokenizer.vocab_size, config.hidden_size, padding_idx=self.src_tokenizer.pad_token_id)
        decoder_embeddings = torch.nn.Embedding(self.tgt_tokenizer.vocab_size, config.hidden_size, padding_idx=self.tgt_tokenizer.pad_token_id)

        encoder = BertModel(encoder_config)
        encoder.set_input_embeddings(encoder_embeddings.cpu())

        decoder = BertForMaskedLM(decoder_config)
        decoder.set_input_embeddings(decoder_embeddings.cpu())

        input_dirs = config.model_output_dirs

        suffix = "pytorch_model.bin"
        decoderPath = os.path.join(input_dirs['decoder'], suffix)
        encoderPath = os.path.join(input_dirs['encoder'], suffix)

        decoder_state_dict = torch.load(decoderPath)
        encoder_state_dict = torch.load(encoderPath)
        decoder.load_state_dict(decoder_state_dict)
        encoder.load_state_dict(encoder_state_dict)
        self.model = TranslationModel(encoder, decoder, None, None, self.tgt_tokenizer, config)
        self.model.cpu()


        #model.eval()
        self.model.encoder.eval()
        self.model.decoder.eval()
Example #18
    # get paths for model weights store/load
    if args.exp_name is not None:
        args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset,
                                            args.exp_name)
    else:
        args.checkpoint_dir = '%s/%s/%s' % (SAVE_DIR, args.dataset,
                                            args.config_name)

    print('checkpoints dir : ', args.checkpoint_dir)
    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    args.start_epoch = 0

    if args.resume:
        # for testing, models have to be loaded from a specific checkpoint path
        if args.iter != -1:
            resume_file = get_assigned_file(args.checkpoint_dir, args.iter)
        else:
            resume_file = get_resume_file(args.checkpoint_dir)
        if resume_file is not None:
            print('Resume file is: ', resume_file)
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            projection.load_state_dict(tmp['projection'])
            model.load_state_dict(tmp['feature'])
        else:
            raise Exception('Resume file not found')

    train(data_loader, model, projection, args)