Example #1
0
    def to_tensor_dict(
            self,
            examples: List[Example],
            return_prediction_target=True) -> Dict[str, torch.Tensor]:
        """Batch `examples` into the tensor dict expected by the configured encoder.

        Dispatches on `config['encoder']['type']` (GraphASTEncoder,
        SequentialEncoder, or HybridEncoder) and always attaches
        `prediction_target` when training or when requested, plus
        `batch_size` / `num_elements` bookkeeping entries.
        """
        from model.sequential_encoder import SequentialEncoder
        from model.graph_encoder import GraphASTEncoder

        # Lazily annotate examples that have not been pre-processed yet.
        if not hasattr(examples[0], 'target_prediction_seq_length'):
            for ex in examples:
                self.annotate_example(ex)

        encoder_config = self.config['encoder']
        encoder_type = encoder_config['type']

        if encoder_type == 'GraphASTEncoder':
            use_seq_init = encoder_config['init_with_seq_encoding']
            packed_graph, tensor_dict = GraphASTEncoder.to_packed_graph(
                [ex.ast for ex in examples],
                connections=encoder_config['connections'],
                init_with_seq_encoding=use_seq_init)

            if use_seq_init:
                tensor_dict['seq_encoder_input'] = (
                    SequentialEncoder.to_tensor_dict(examples))

            tensor_dict.update(
                GraphASTEncoder.to_tensor_dict(packed_graph, self.grammar,
                                               self.vocab))
        elif encoder_type == 'SequentialEncoder':
            tensor_dict = SequentialEncoder.to_tensor_dict(examples)
        elif encoder_type == 'HybridEncoder':
            packed_graph, gnn_tensor_dict = GraphASTEncoder.to_packed_graph(
                [ex.ast for ex in examples],
                connections=encoder_config['graph_encoder']['connections'])
            gnn_tensor_dict.update(
                GraphASTEncoder.to_tensor_dict(packed_graph, self.grammar,
                                               self.vocab))

            tensor_dict = {
                'graph_encoder_input': gnn_tensor_dict,
                'seq_encoder_input': SequentialEncoder.to_tensor_dict(examples)
            }
        else:
            raise ValueError('UnknownEncoderType')

        # Targets are needed for training and, on request, for evaluation.
        if self.train or return_prediction_target:
            tensor_dict['prediction_target'] = (
                self.to_batched_prediction_target(examples))

        if not self.train and hasattr(examples[0], 'test_meta'):
            tensor_dict['test_meta'] = [ex.test_meta for ex in examples]

        tensor_dict['batch_size'] = len(examples)
        # Size is computed after batch_size is inserted, matching prior layout.
        tensor_dict['num_elements'] = nn_util.get_tensor_dict_size(tensor_dict)

        return tensor_dict
Example #2
0
    def __init__(self, config):
        """Build the hybrid (graph + sequential) encoder.

        Args:
            config: dict with 'graph_encoder' and 'seq_encoder' sub-configs
                and 'hybrid_method' — either 'linear_proj' (project the
                concatenated encodings) or 'concat' (use them as-is).

        Raises:
            ValueError: if `config['hybrid_method']` is not a known method.
        """
        super(HybridEncoder, self).__init__()

        self.graph_encoder = GraphASTEncoder.build(config['graph_encoder'])
        self.seq_encoder = SequentialEncoder.build(config['seq_encoder'])

        self.hybrid_method = config['hybrid_method']
        if self.hybrid_method == 'linear_proj':
            # Project concatenated seq + graph encodings down to the
            # configured source encoding size (no bias, as before).
            self.projection = nn.Linear(
                config['seq_encoder']['decoder_hidden_size']
                + config['graph_encoder']['gnn']['hidden_size'],
                config['source_encoding_size'],
                bias=False)
        elif self.hybrid_method != 'concat':
            # Was `assert ... == 'concat'`: asserts are stripped under
            # `python -O`, so validate configuration explicitly.
            raise ValueError(
                'unknown hybrid_method: {!r}'.format(self.hybrid_method))

        self.config = config
Example #3
0
 def default_params(cls):
     """Return the default configuration dict for the hybrid encoder."""
     params = {
         "graph_encoder": GraphASTEncoder.default_params(),
         "seq_encoder": SequentialEncoder.default_params(),
     }
     params["hybrid_method"] = "linear_proj"
     return params
Example #4
0
    def __init__(self, vocab, top_k, config, device):
        """Hybrid renaming model: graph encoder + pretrained BERT encoder/decoder.

        Loads two pretrained BERT checkpoints from hard-coded paths, strips
        their pre-training heads, and wires them to a graph encoder.

        Args:
            vocab: vocabulary exposing `source_tokens` and `all_subtokens`.
            top_k: number of top predictions to keep.
            config: model config; reads `config['encoder']['graph_encoder']`.
            device: map location for loading the pretrained checkpoints.
        """
        super(RenamingModelHybrid, self).__init__()
        from collections import OrderedDict  # single import (was imported twice)

        self.vocab = vocab
        self.top_k = top_k
        # +1 reserves an extra id on top of the vocab — presumably padding;
        # TODO(review): confirm against the vocab implementation.
        self.source_vocab_size = len(self.vocab.source_tokens) + 1

        self.graph_encoder = GraphASTEncoder.build(
            config['encoder']['graph_encoder'])
        self.graph_emb_size = config['encoder']['graph_encoder']['gnn'][
            'hidden_size']
        self.emb_size = 256

        # Pre-training head parameters that do not exist in a bare BertModel;
        # they must be dropped before loading. frozenset for O(1) membership.
        keys_to_delete = frozenset([
            "cls.predictions.bias", "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.LayerNorm.weight",
            "cls.predictions.transform.LayerNorm.bias",
            "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
            "cls.seq_relationship.weight", "cls.seq_relationship.bias"
        ])
        bert_prefix_len = len('bert.')  # was a magic `k[5:]`

        def _strip_checkpoint(model_state, skip_cross_attention=False):
            """Drop head (and optionally cross-attention) keys and the 'bert.' prefix."""
            stripped = OrderedDict()
            for key, value in model_state.items():
                if key in keys_to_delete:
                    continue
                if skip_cross_attention and 'crossattention' in key:
                    continue
                stripped[key[bert_prefix_len:]] = value
            return stripped

        # --- Pretrained BERT encoder over source tokens. ---
        state_dict = torch.load(
            'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
            map_location=device)

        bert_config = BertConfig(vocab_size=self.source_vocab_size,
                                 max_position_embeddings=512,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4)
        self.bert_encoder = BertModel(bert_config)
        self.bert_encoder.load_state_dict(
            _strip_checkpoint(state_dict['model']))

        self.target_vocab_size = len(self.vocab.all_subtokens) + 1

        # --- Pretrained BERT decoder over target subtokens. ---
        bert_config = BertConfig(vocab_size=self.target_vocab_size,
                                 max_position_embeddings=1000,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4,
                                 is_decoder=True)
        self.bert_decoder = BertModel(bert_config)

        state_dict = torch.load(
            'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
            map_location=device)
        decoder_state = _strip_checkpoint(state_dict['model'],
                                          skip_cross_attention=True)

        # Partial load: the decoder has cross-attention weights the checkpoint
        # lacks, so copy the available tensors in place instead of
        # `load_state_dict` (which would fail on missing keys).
        for key in decoder_state:
            self.bert_decoder.state_dict()[key].copy_(decoder_state[key])

        self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size,
                                       self.emb_size)
        self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)

        # Initialize the output projection from the decoder checkpoint's LM
        # head weight (taken from the second `state_dict` loaded above).
        self.fc_final.weight.data = state_dict['model'][
            'cls.predictions.decoder.weight']