def to_tensor_dict(self,
                   examples: List[Example],
                   return_prediction_target=True) -> Dict[str, torch.Tensor]:
    from model.sequential_encoder import SequentialEncoder
    from model.graph_encoder import GraphASTEncoder

    # Annotate examples lazily if they have not been preprocessed yet.
    if not hasattr(examples[0], 'target_prediction_seq_length'):
        for example in examples:
            self.annotate_example(example)

    if self.config['encoder']['type'] == 'GraphASTEncoder':
        init_with_seq_encoding = self.config['encoder']['init_with_seq_encoding']
        packed_graph, tensor_dict = GraphASTEncoder.to_packed_graph(
            [e.ast for e in examples],
            connections=self.config['encoder']['connections'],
            init_with_seq_encoding=init_with_seq_encoding)

        if init_with_seq_encoding:
            # Node states are initialized from a sequential encoding, so the
            # batch also carries the sequence encoder's inputs.
            seq_tensor_dict = SequentialEncoder.to_tensor_dict(examples)
            tensor_dict['seq_encoder_input'] = seq_tensor_dict

        _tensors = GraphASTEncoder.to_tensor_dict(packed_graph, self.grammar, self.vocab)
        tensor_dict.update(_tensors)
    elif self.config['encoder']['type'] == 'SequentialEncoder':
        tensor_dict = SequentialEncoder.to_tensor_dict(examples)
    elif self.config['encoder']['type'] == 'HybridEncoder':
        # Build inputs for both sub-encoders and keep them under separate keys.
        packed_graph, gnn_tensor_dict = GraphASTEncoder.to_packed_graph(
            [e.ast for e in examples],
            connections=self.config['encoder']['graph_encoder']['connections'])
        gnn_tensors = GraphASTEncoder.to_tensor_dict(packed_graph, self.grammar, self.vocab)
        gnn_tensor_dict.update(gnn_tensors)

        seq_tensor_dict = SequentialEncoder.to_tensor_dict(examples)

        tensor_dict = {
            'graph_encoder_input': gnn_tensor_dict,
            'seq_encoder_input': seq_tensor_dict
        }
    else:
        raise ValueError(f"Unknown encoder type: {self.config['encoder']['type']}")

    if self.train or return_prediction_target:
        tensor_dict['prediction_target'] = self.to_batched_prediction_target(examples)

    if not self.train and hasattr(examples[0], 'test_meta'):
        tensor_dict['test_meta'] = [e.test_meta for e in examples]

    tensor_dict['batch_size'] = len(examples)
    tensor_dict['num_elements'] = nn_util.get_tensor_dict_size(tensor_dict)

    return tensor_dict
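# Illustrative, self-contained sketch (not the actual `nn_util` helper) of what
# `get_tensor_dict_size` computes above: the total number of tensor elements in
# a possibly nested tensor dictionary, which callers can use to balance batch
# sizes. `tensor_dict_size` is a hypothetical name for this sketch.
#
# import torch
#
# def tensor_dict_size(d: dict) -> int:
#     total = 0
#     for value in d.values():
#         if torch.is_tensor(value):
#             total += value.numel()            # elements in this tensor
#         elif isinstance(value, dict):
#             total += tensor_dict_size(value)  # recurse into nested inputs
#         # non-tensor entries (e.g. `batch_size`, `test_meta`) are skipped
#     return total
#
# # tensor_dict_size({'x': torch.zeros(2, 3), 'sub': {'y': torch.ones(4)}}) == 10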
def __init__(self, config):
    super(HybridEncoder, self).__init__()

    self.graph_encoder = GraphASTEncoder.build(config['graph_encoder'])
    self.seq_encoder = SequentialEncoder.build(config['seq_encoder'])

    self.hybrid_method = config['hybrid_method']
    if self.hybrid_method == 'linear_proj':
        # Project the concatenated sequence and graph encodings down to the
        # source encoding size.
        self.projection = nn.Linear(
            config['seq_encoder']['decoder_hidden_size'] +
            config['graph_encoder']['gnn']['hidden_size'],
            config['source_encoding_size'],
            bias=False)
    else:
        assert self.hybrid_method == 'concat'

    self.config = config
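# Minimal, self-contained sketch (not part of the original source) of the
# `linear_proj` fusion set up in the constructor above: concatenate a sequence
# encoding with a graph encoding along the feature axis, then project without
# bias. All sizes are arbitrary placeholders, and the assumption that the two
# encodings are aligned position-for-position is ours.
#
# import torch
# import torch.nn as nn
#
# seq_hidden, gnn_hidden, source_encoding_size = 256, 128, 256
# projection = nn.Linear(seq_hidden + gnn_hidden, source_encoding_size, bias=False)
#
# seq_encoding = torch.randn(4, 50, seq_hidden)    # (batch, positions, decoder_hidden_size)
# graph_encoding = torch.randn(4, 50, gnn_hidden)  # (batch, positions, gnn hidden_size)
# fused = projection(torch.cat([seq_encoding, graph_encoding], dim=-1))
# assert fused.shape == (4, 50, source_encoding_size)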
@classmethod
def default_params(cls):
    return {
        "graph_encoder": GraphASTEncoder.default_params(),
        "seq_encoder": SequentialEncoder.default_params(),
        "hybrid_method": "linear_proj"
    }
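# Illustrative sketch (not in the original source) of one way to overlay a
# user-supplied config onto the defaults returned by `default_params`.
# `merge_configs` is a hypothetical helper, not part of this codebase.
#
# def merge_configs(defaults: dict, overrides: dict) -> dict:
#     merged = dict(defaults)
#     for key, value in overrides.items():
#         if isinstance(value, dict) and isinstance(merged.get(key), dict):
#             merged[key] = merge_configs(merged[key], value)  # nested configs
#         else:
#             merged[key] = value
#     return merged
#
# # e.g. config = merge_configs(HybridEncoder.default_params(),
# #                             {'hybrid_method': 'concat'})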
def __init__(self, vocab, top_k, config, device):
    super(RenamingModelHybrid, self).__init__()

    self.vocab = vocab
    self.top_k = top_k
    self.source_vocab_size = len(self.vocab.source_tokens) + 1

    self.graph_encoder = GraphASTEncoder.build(config['encoder']['graph_encoder'])
    self.graph_emb_size = config['encoder']['graph_encoder']['gnn']['hidden_size']
    self.emb_size = 256

    # Load the pretrained BERT encoder checkpoint, dropping the masked-LM and
    # next-sentence-prediction head parameters that the encoder does not need.
    state_dict = torch.load(
        'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
        map_location=device)
    keys_to_delete = [
        "cls.predictions.bias",
        "cls.predictions.transform.dense.weight",
        "cls.predictions.transform.dense.bias",
        "cls.predictions.transform.LayerNorm.weight",
        "cls.predictions.transform.LayerNorm.bias",
        "cls.predictions.decoder.weight",
        "cls.predictions.decoder.bias",
        "cls.seq_relationship.weight",
        "cls.seq_relationship.bias"
    ]
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict['model'].items():
        if k in keys_to_delete:
            continue
        name = k[5:]  # strip the `bert.` prefix
        new_state_dict[name] = v

    bert_config = BertConfig(vocab_size=self.source_vocab_size,
                             max_position_embeddings=512,
                             num_hidden_layers=6,
                             hidden_size=self.emb_size,
                             num_attention_heads=4)
    self.bert_encoder = BertModel(bert_config)
    self.bert_encoder.load_state_dict(new_state_dict)

    # Build the BERT decoder and load its pretrained weights, skipping the
    # cross-attention parameters, which are left freshly initialized.
    self.target_vocab_size = len(self.vocab.all_subtokens) + 1
    bert_config = BertConfig(vocab_size=self.target_vocab_size,
                             max_position_embeddings=1000,
                             num_hidden_layers=6,
                             hidden_size=self.emb_size,
                             num_attention_heads=4,
                             is_decoder=True)
    self.bert_decoder = BertModel(bert_config)

    state_dict = torch.load(
        'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
        map_location=device)
    new_state_dict = OrderedDict()
    for k, v in state_dict['model'].items():
        if k in keys_to_delete:
            continue
        if 'crossattention' in k:
            continue
        name = k[5:]  # strip the `bert.` prefix
        new_state_dict[name] = v
    for key in new_state_dict:
        self.bert_decoder.state_dict()[key].copy_(new_state_dict[key])

    # Fuse BERT encodings with graph encodings, then score target subtokens.
    self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size, self.emb_size)
    self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)
    # Initialize the output projection from the pretrained decoder's LM head.
    self.fc_final.weight.data = state_dict['model']['cls.predictions.decoder.weight']
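# Minimal sketch (not from the original source) of how the modules built above
# might plug together at inference time: BERT token encodings are concatenated
# with the graph encoder's output, mapped back to `emb_size` via
# `enc_graph_map` to serve as the decoder's cross-attention memory, and decoder
# states are scored with `fc_final`. Shapes are placeholders, and the
# token/node alignment is an assumption of this sketch.
#
# import torch
# import torch.nn as nn
#
# emb_size, graph_emb_size, target_vocab_size = 256, 128, 10000
# enc_graph_map = nn.Linear(emb_size + graph_emb_size, emb_size)
# fc_final = nn.Linear(emb_size, target_vocab_size)
#
# bert_encoding = torch.randn(2, 64, emb_size)         # (batch, src_tokens, emb)
# graph_encoding = torch.randn(2, 64, graph_emb_size)  # (batch, nodes, gnn hidden)
# memory = enc_graph_map(torch.cat([bert_encoding, graph_encoding], dim=-1))
#
# decoder_states = torch.randn(2, 16, emb_size)  # stand-in for bert_decoder output
#                                                # attending over `memory`
# logits = fc_final(decoder_states)              # (batch, tgt_tokens, target_vocab_size)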