Exemple #1
0
    def load_model(self, run_id, snapshot_iteration, gpu=True):
        """Rebuild the GREAT-encoder decoder of a stored run and load a snapshot into it.

        Args:
            run_id: run whose parameters and config are read via ``self.load_parameters`` and
                ``self.load_config``.
            snapshot_iteration: which stored snapshot's parameters to load.
            gpu: forwarded unchanged to ``self.load_parameters``.

        Returns:
            A ``GreatTransformerDecoder`` holding the snapshot's state dict. If loading raises
            ``RuntimeError``, two legacy layouts are tried in order: the same model with an extra
            ``encoder_self_attention`` module, then a freshly built model with
            ``concat_query_and_pointer=False`` (whose load error, if any, propagates).
        """
        model_params = self.load_parameters(run_id, snapshot_iteration, gpu=gpu)
        config = self.load_config(run_id)
        model_config = self._prepare_model_config(config)
        data_setup = config['data_setup']

        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2, data_setup['language'])
        decoder_config = model_config['lm_decoder']
        word_vocab, token_type_vocab, node_type_vocab = data_manager.load_vocabularies()

        # --- encoder: sizes come from the preprocessed dataset's vocabularies ---------------
        encoder_config = model_config['lm_encoder']
        encoder_config['num_node_types'] = len(node_type_vocab)
        encoder_config['vocab_size'] = len(word_vocab)
        great_config = GreatTransformerConfig(**encoder_config['transformer_config'])
        encoder_config['transformer_config'] = great_config

        # bias_dim is the number of edge types: ancestor/sibling shortest paths add two each,
        # plain shortest paths add one, every other relative distance adds none.
        edge_types_per_distance = {"ancestor_sp": 2, "sibling_sp": 2, "shortest_paths": 1}
        great_config.bias_dim = sum(edge_types_per_distance.get(distance, 0)
                                    for distance in config['data_transforms']['relative_distances'])

        languages = data_manager.language
        if ',' in languages:
            # Comma-separated language list => one language id per entry.
            encoder_config['num_languages'] = len(languages.split(','))

        great_lm_encoder = GreatEncoderTransformerAdapter(GreatEncoderConfig(**encoder_config))

        # --- decoder ---------------------------------------------------------------------
        decoder_config['sos_id'] = word_vocab[SOS_TOKEN]
        decoder_config['output_subtokens_per_token'] = data_setup.get('num_subtokens_output',
                                                                      NUM_SUB_TOKENS)
        if 'use_pointer_network' in data_setup:
            decoder_config['use_pointer_network'] = data_setup['use_pointer_network']
        decoder_config['lm_encoder'] = great_lm_encoder
        decoder_config['loss_fct'] = model_config['loss_fct']

        model = GreatTransformerDecoder(TransformerLMDecoderConfig(**decoder_config))
        try:
            model.load_state_dict(model_params)
        except RuntimeError:
            pass  # In most cases, this is due to the legacy issue with encoder_self_attention
        else:
            return model

        model.add_module('encoder_self_attention',
                         MultiheadAttention(model.d_model, decoder_config['decoder_nhead'],
                                            dropout=decoder_config['decoder_dropout']))
        try:
            model.load_state_dict(model_params)
        except RuntimeError:
            pass  # Fall through to the checkpoint layout without query/pointer concatenation.
        else:
            return model

        decoder_config['concat_query_and_pointer'] = False
        model = GreatTransformerDecoder(TransformerLMDecoderConfig(**decoder_config))
        model.load_state_dict(model_params)
        return model
Exemple #2
0
    def test_transformer_decoder_language_modeling(self):
        """Check that a pointer-network LM decoder can copy labels it is able to locate.

        One batch is reused for two freshly created models, each trained for 100 steps. Every
        step draws new random labels (only sub-token 0 of each label is set; the rest stay PAD)
        and writes them into the first ``n_predict`` input token positions:

        1. The labels are also written into the first ``n_predict`` positions of
           ``extended_vocabulary_ids`` => the pointer finds them and loss must fall below 1.
        2. ``extended_vocabulary_ids`` is not updated (it keeps stale values from scenario 1)
           => the model has no chance and loss must stay above 0.5.
        """
        n_predict = 2
        # One planted position per predicted label (previously hard-coded as [0, 1]).
        label_positions = list(range(n_predict))
        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                                 language='java-small',
                                                 partition='valid')
        dataset = CTLanguageModelingDatasetNoPunctuation(
            data_manager,
            use_pointer_network=True,
            num_labels_per_sample=n_predict)
        dataloader = DataLoader(dataset,
                                collate_fn=dataset.collate_fn,
                                batch_size=2)
        word_vocab, _, _ = data_manager.load_vocabularies()

        batch = next(iter(dataloader))
        batch.labels[:] = word_vocab[PAD_TOKEN]

        def train(expose_labels_to_pointer):
            """Train a fresh decoder for 100 steps and return the output of the last step."""
            model = self._create_transformer_decoder_model(
                data_manager, output_subtokens_per_token=5)
            optimizer = optim.Adam(model.parameters(), lr=5e-3)
            output = None
            tq = tqdm(range(100))
            for _ in tq:
                # Size taken from the batch rather than hard-coding (batch_size, n_predict).
                batch.labels[:, :, 0] = torch.randint(len(word_vocab),
                                                      size=batch.labels[:, :, 0].shape)
                labels = batch.labels[:, :, 0]
                batch.tokens[:, label_positions, 0] = labels
                if expose_labels_to_pointer:
                    batch.extended_vocabulary_ids[:, label_positions] = labels
                output = model.forward_batch(batch)
                output.loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                tq.set_postfix(loss=output.loss.item())
            return output

        # random labels that are always at same position in extended_vocabulary_ids => Pointer finds them
        output = train(expose_labels_to_pointer=True)
        self.assertFalse(torch.isnan(output.logits).any())
        self.assertLess(output.loss.item(), 1)

        # random labels NOT written into extended_vocabulary_ids => no chance
        output = train(expose_labels_to_pointer=False)
        self.assertFalse(torch.isnan(output.logits).any())
        self.assertGreater(output.loss.item(), 0.5)
    def test_great(self):
        """Overfit a GREAT model on one multi-language training batch; loss must fall below 1."""
        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                                 language='python,javascript,go,ruby',
                                                 partition='train',
                                                 infinite_loading=True)
        dataset = CTCodeSummarizationDatasetEdgeTypes(data_manager,
                                                      num_sub_tokens_output=6,
                                                      use_pointer_network=True)
        loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
        # The iterator is created before the model, as in the original ordering.
        batches = iter(loader)

        model = self._create_great_model(data_manager, use_pointer_network=True)
        optimizer = optim.Adam(model.parameters(), lr=8e-3)

        def optimization_step():
            """Run one forward/backward/update on ``batch`` and return the model output."""
            result = model.forward_batch(batch)
            result.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            return result

        progress = tqdm(range(100))
        batch = next(batches)
        # Train repeatedly on this single batch.
        for _step in progress:
            output = optimization_step()
            progress.set_postfix(loss=output.loss.item())

        self.assertLess(output.loss.item(), 1)
Exemple #4
0
    def test_xl_net(self):
        """Overfit an XLNet model on one java-small validation batch; loss must fall below 1."""
        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                                 language='java-small',
                                                 partition='valid')
        dataset = CTCodeSummarizationDatasetNoPunctuation(data_manager,
                                                          num_sub_tokens_output=6,
                                                          use_pointer_network=True)
        loader = DataLoader(dataset, batch_size=3, collate_fn=dataset.collate_fn)
        batches = iter(loader)
        batch = next(batches)

        model = self._create_xl_net_model(data_manager)
        optimizer = optim.Adam(model.parameters(), lr=8e-3)

        def optimization_step():
            """Run one forward/backward/update on ``batch`` and return the model output."""
            result = model.forward_batch(batch)
            result.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            return result

        progress = tqdm(range(100))
        for _step in progress:
            output = optimization_step()
            progress.set_postfix(loss=output.loss.item())

        self.assertLess(output.loss.item(), 1)
Exemple #5
0
    def load_model(self, run_id, snapshot_iteration, gpu=True):
        """Rebuild the XLNet-encoder decoder of a stored run and load a snapshot into it.

        Args:
            run_id: run whose parameters and config are read via ``self.load_parameters`` and
                ``self.load_config``.
            snapshot_iteration: which stored snapshot's parameters to load.
            gpu: forwarded unchanged to ``self.load_parameters``.

        Returns:
            An ``XLNetTransformerDecoder`` holding the snapshot's state dict. If loading raises
            ``RuntimeError``, two legacy layouts are tried in order: the same model with an extra
            ``encoder_self_attention`` module, then a freshly built model with
            ``concat_query_and_pointer=False`` (whose load error, if any, propagates).
        """
        model_params = self.load_parameters(run_id, snapshot_iteration, gpu=gpu)
        config = self.load_config(run_id)
        model_config = self._prepare_model_config(config)
        data_setup = config['data_setup']

        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2, data_setup['language'])

        decoder_config = model_config['lm_decoder']

        word_vocab, token_type_vocab, _ = data_manager.load_vocabularies()

        transformer_encoder_config = model_config['lm_encoder']
        transformer_encoder_config['num_token_types'] = len(token_type_vocab)
        transformer_encoder_config['vocab_size'] = len(word_vocab)

        decoder_config['sos_id'] = word_vocab[SOS_TOKEN]
        decoder_config['output_subtokens_per_token'] = data_setup.get('num_subtokens_output',
                                                                      NUM_SUB_TOKENS)
        if 'use_pointer_network' in data_setup:
            decoder_config['use_pointer_network'] = data_setup['use_pointer_network']

        decoder_config['lm_encoder'] = transformer_encoder_config
        decoder_config['loss_fct'] = model_config['loss_fct']

        model = XLNetTransformerDecoder(TransformerLMDecoderConfig(**decoder_config))
        try:
            model.load_state_dict(model_params)
        except RuntimeError:
            pass  # In most cases, this is due to the legacy issue with encoder_self_attention
        else:
            return model

        model.add_module('encoder_self_attention',
                         MultiheadAttention(model.d_model, decoder_config['decoder_nhead'],
                                            dropout=decoder_config['decoder_dropout']))
        try:
            model.load_state_dict(model_params)
        except RuntimeError:
            pass  # Fall through to the layout without query/pointer concatenation.
        else:
            return model

        decoder_config['concat_query_and_pointer'] = False
        # Rebuild the same architecture (previously this wrongly built a CodeTransformerDecoder).
        model = XLNetTransformerDecoder(TransformerLMDecoderConfig(**decoder_config))
        model.load_state_dict(model_params)
        return model
Exemple #6
0
 def __init__(self, data_manager: CTPreprocessedDataManager, token_distances=None, max_distance_mask=None,
              num_sub_tokens=5, num_labels_per_sample=5, min_sequence_length=5, max_num_tokens=MAX_NUM_TOKENS,
              use_pointer_network=False):
     """Set up the no-punctuation language-modeling dataset.

     The parent constructor is always called with ``max_num_tokens=None``; the requested
     ``max_num_tokens`` is kept on this instance as ``max_num_tokens_no_punctuation`` instead.
     ``min_sequence_length`` is not passed to the parent and is only stored here, next to the
     preprocessed dataset config loaded from ``data_manager``.
     """
     super(CTLanguageModelingDatasetNoPunctuation, self).__init__(
         data_manager,
         use_pointer_network=use_pointer_network,
         max_num_tokens=None,
         num_labels_per_sample=num_labels_per_sample,
         num_sub_tokens=num_sub_tokens,
         max_distance_mask=max_distance_mask,
         token_distances=token_distances,
     )
     self.config = data_manager.load_config()
     # NOTE(review): presumably applied after punctuation removal — TODO confirm.
     self.max_num_tokens_no_punctuation = max_num_tokens
     self.min_sequence_length = min_sequence_length
Exemple #7
0
def decode_tokens(tokens: torch.Tensor,
                  data_manager: CTPreprocessedDataManager = None,
                  word_vocab=None,
                  config=None) -> List[List[str]]:
    """Map token ids back to strings, dropping padding ids.

    Args:
        tokens: an iterable whose elements are either single token ids (Python ints or 0-d
            tensors) or sequences of sub-token ids (lists, tuples or non-scalar tensors whose
            elements are ints or 0-d tensors).
        data_manager: used to load ``word_vocab`` and/or ``config`` when they are not given.
        word_vocab: vocabulary providing ``reverse_lookup(id)``.
        config: preprocessing config; ``config['preprocessing']['special_symbols'][PAD_TOKEN]``
            is the padding id that gets dropped.

    Returns:
        One entry per non-padding element of ``tokens``: a list of sub-token strings when the
        element is a sequence, otherwise a single string. Sequence elements are always kept
        (possibly as an empty list); only their padding sub-tokens are dropped.

    Raises:
        AssertionError: if neither ``data_manager`` nor both ``word_vocab`` and ``config``
            are provided.
    """
    assert data_manager is not None or (word_vocab is not None and config is not None), \
        "Either data_manager or word_vocab and config have to be provided"
    if word_vocab is None:
        word_vocab, _, _ = data_manager.load_vocabularies()
    if config is None:
        config = data_manager.load_config()
    pad_id = config['preprocessing']['special_symbols'][PAD_TOKEN]

    def to_id(value):
        # Accept both plain ints and 0-d tensors (the old code called .item() unconditionally,
        # which crashed on lists of ints).
        return value.item() if isinstance(value, torch.Tensor) else value

    words = []
    for token in tokens:
        # A 0-d tensor is a single id, not a sequence (iterating it raises TypeError).
        is_sequence = isinstance(token, (list, tuple)) or (
            isinstance(token, torch.Tensor) and token.dim() > 0)
        if is_sequence:
            words.append([
                word_vocab.reverse_lookup(to_id(sub_token))
                for sub_token in token if to_id(sub_token) != pad_id
            ])
        elif to_id(token) != pad_id:
            words.append(word_vocab.reverse_lookup(to_id(token)))

    return words
 def __init__(self,
              data_manager: CTPreprocessedDataManager,
              token_distances=None,
              max_distance_mask=None,
              num_sub_tokens=5,
              num_sub_tokens_output=5,
              use_token_types=True,
              use_pointer_network=False,
              max_num_tokens=MAX_NUM_TOKENS):
     """Set up the no-punctuation code-summarization dataset.

     All arguments except ``max_num_tokens`` are forwarded to the parent constructor
     (``token_distances`` through ``use_token_types`` positionally). The parent is always
     built with ``max_num_tokens=None``; the requested limit is stored on this instance as
     ``self.max_num_tokens_no_punctuation`` instead.

     Args:
         data_manager: preprocessed-data manager; also used to load ``self.config``.
         max_num_tokens: token limit kept on this instance rather than passed to the parent.
     """
     # NOTE(review): the parent receives no token limit; presumably the limit is enforced on
     # the punctuation-free sequence via max_num_tokens_no_punctuation — TODO confirm.
     super(CTCodeSummarizationDatasetNoPunctuation,
           self).__init__(data_manager,
                          token_distances,
                          max_distance_mask,
                          num_sub_tokens,
                          num_sub_tokens_output,
                          use_token_types,
                          use_pointer_network=use_pointer_network,
                          max_num_tokens=None)
     # Preprocessed dataset config as returned by data_manager.load_config().
     self.config = data_manager.load_config()
     self.max_num_tokens_no_punctuation = max_num_tokens
Exemple #9
0
 def __init__(self,
              data_manager: CTPreprocessedDataManager,
              token_distances=None,
              max_distance_mask=None,
              num_sub_tokens=5,
              num_sub_tokens_output=5,
              use_pointer_network=False,
              max_num_tokens=MAX_NUM_TOKENS,
              mask_all_tokens=False):
     """Set up the only-AST code-summarization dataset.

     The parent constructor always receives ``use_token_types=False`` and
     ``max_num_tokens=None``; the requested ``max_num_tokens`` is stored on this instance as
     ``self.max_num_tokens_only_ast`` instead.

     Args:
         data_manager: preprocessed-data manager; also used to load ``self.config``.
         max_num_tokens: token limit kept on this instance rather than passed to the parent.
         mask_all_tokens: stored as ``self.mask_all_tokens``; its effect is implemented
             elsewhere in the class — TODO document once confirmed.
     """
     super(CTCodeSummarizationOnlyASTDataset,
           self).__init__(data_manager,
                          token_distances,
                          max_distance_mask,
                          num_sub_tokens,
                          num_sub_tokens_output,
                          use_token_types=False,
                          use_pointer_network=use_pointer_network,
                          max_num_tokens=None)
     # NOTE(review): presumably applied after restricting the sample to AST nodes — TODO confirm.
     self.max_num_tokens_only_ast = max_num_tokens
     # Preprocessed dataset config as returned by data_manager.load_config().
     self.config = data_manager.load_config()
     self.mask_all_tokens = mask_all_tokens
def make_batch_from_sample(stage2_sample: CTStage2Sample, model_config,
                           model_type):
    """Collate a single stage-2 sample into a batch for the given model type.

    The dataset class matches how the model was trained:

    - ``'great'`` models use the edge-type dataset;
    - otherwise, ``use_only_ast`` / ``use_no_punctuation`` in ``model_config['data_setup']``
      select the only-AST / no-punctuation variants. Missing flags count as ``False``.
    - If no flag is set, the regular code-summarization dataset is used.

    Args:
        stage2_sample: preprocessed sample; its ``token_mapping`` must be a dict.
        model_config: experiment config; reads ``data_setup`` (``language``,
            ``use_pointer_network`` and the optional flags above) and ``data_transforms``.
        model_type: model identifier; only ``'great'`` is treated specially.

    Returns:
        The result of ``dataset.collate_fn`` applied to the single processed sample.

    Raises:
        AssertionError: if ``stage2_sample.token_mapping`` is not a dict.
    """
    assert isinstance(stage2_sample.token_mapping,
                      dict), "Please re-generate the sample"
    data_setup = model_config['data_setup']
    data_manager = CTPreprocessedDataManager(
        DATA_PATH_STAGE_2,
        data_setup['language'],
        partition='train',
        shuffle=True)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    token_distances = None
    if TokenDistancesTransform.name in model_config['data_transforms'][
            'relative_distances']:
        num_bins = data_manager.load_config()['num_bins']
        distance_binning_config = model_config['data_transforms'][
            'distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'],
                            trans_func))

    use_pointer_network = data_setup['use_pointer_network']
    if model_type == 'great':
        dataset_cls = CTCodeSummarizationDatasetEdgeTypes
    elif data_setup.get('use_only_ast', False):
        dataset_cls = CTCodeSummarizationOnlyASTDataset
    elif data_setup.get('use_no_punctuation', False):
        dataset_cls = CTCodeSummarizationDatasetNoPunctuation
    else:
        dataset_cls = CTCodeSummarizationDataset

    # Every dataset variant is constructed with the same keyword arguments here.
    dataset = dataset_cls(
        data_manager,
        num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
        use_pointer_network=use_pointer_network,
        max_num_tokens=LIMIT_TOKENS,
        token_distances=token_distances)

    # Hijack dataset to only contain user specified code snippet
    dataset.dataset = iter([stage2_sample])
    processed_sample = next(dataset)
    return dataset.collate_fn([processed_sample])
def predict_method_name(model,
                        model_config,
                        code_snippet: str,
                        method_name_place_holder='f'):
    """Run ``model`` on a raw code snippet and return its raw output.

    The snippet goes through the same pipeline as the stored dataset:

    1. stage-1 preprocessing (AST computation);
    2. stage-2 preprocessing (optional punctuation removal, vocabulary lookup, distance
       matrices), configured from the dataset config of ``data_manager``;
    3. a code-summarization dataset that turns the single sample into a batch.

    Args:
        model: object exposing ``forward_batch(batch)``.
        model_config: experiment config. It reads ``data_setup``: ``language``,
            ``use_pointer_network`` and the optional ``use_no_punctuation``, which is treated
            as ``False`` when absent. It also reads ``data_transforms``.
        code_snippet: source code of a single method.
        method_name_place_holder: placed in the first slot of the tuple handed to stage-1
            preprocessing (presumably the method name).

    Returns:
        Whatever ``model.forward_batch`` returns for the one-sample batch.
    """
    data_setup = model_config['data_setup']
    language = data_setup['language']

    # Build data manager and load vocabularies + configs
    data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                             language,
                                             partition='train',
                                             shuffle=True)
    vocabs = data_manager.load_vocabularies()
    data_config = data_manager.load_config()

    # Stage 1 Preprocessing (Compute AST); all Java dataset variants share the 'java' lexer.
    java_variants = {
        'java-small', 'java-medium', 'java-large', 'java-small-pretrain',
        'java-pretrain'
    }
    lexer_language = 'java' if language in java_variants else language
    preprocessor = CTStage1Preprocessor(lexer_language,
                                        allow_empty_methods=True)
    stage1 = preprocessor.process(
        [(method_name_place_holder, "", code_snippet)], 0)

    # Stage 2 Preprocessing (Compute Distance Matrices) with the stored dataset settings
    distances_config = data_config['distances']
    binning_config = data_config['binning']
    distance_metrics = [
        PersonalizedPageRank(threshold=distances_config['ppr_threshold'],
                             log=distances_config['ppr_use_log'],
                             alpha=distances_config['ppr_alpha']),
        ShortestPaths(threshold=distances_config['sp_threshold']),
        AncestorShortestPaths(
            forward=distances_config['ancestor_sp_forward'],
            backward=distances_config['ancestor_sp_backward'],
            negative_reverse_dists=distances_config['ancestor_sp_negative_reverse_dists'],
            threshold=distances_config['ancestor_sp_threshold']),
        SiblingShortestPaths(
            forward=distances_config['sibling_sp_forward'],
            backward=distances_config['sibling_sp_backward'],
            negative_reverse_dists=distances_config['sibling_sp_negative_reverse_dists'],
            threshold=distances_config['sibling_sp_threshold'])
    ]

    # num_bins: how many bins should be calculated for the values in distance matrices
    db = DistanceBinning(binning_config['num_bins'], binning_config['n_fixed_bins'],
                         ExponentialBinning(binning_config['exponential_binning_growth_factor']))
    distances_transformer = DistancesTransformer(distance_metrics, db)

    if len(vocabs) == 4:
        # A fourth vocabulary is the separate method-name vocabulary.
        vocabulary_transformer = CodeSummarizationVocabularyTransformer(*vocabs)
    else:
        vocabulary_transformer = VocabularyTransformer(*vocabs)

    stage2 = stage1[0]
    if data_config['preprocessing']['remove_punctuation']:
        stage2.remove_punctuation()
    stage2 = vocabulary_transformer(stage2)
    stage2 = distances_transformer(stage2)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    token_distances = None
    if TokenDistancesTransform.name in model_config['data_transforms'][
            'relative_distances']:
        # data_config already holds data_manager.load_config(); no need to reload it.
        num_bins = data_config['num_bins']
        distance_binning_config = model_config['data_transforms'][
            'distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'],
                            trans_func))

    # A missing use_no_punctuation flag means False (it used to raise KeyError).
    if data_setup.get('use_no_punctuation', False):
        dataset_cls = CTCodeSummarizationDatasetNoPunctuation
    else:
        dataset_cls = CTCodeSummarizationDataset
    dataset = dataset_cls(
        data_manager,
        num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
        use_pointer_network=data_setup['use_pointer_network'],
        max_num_tokens=LIMIT_TOKENS,
        token_distances=token_distances)

    # Hijack dataset to only contain user specified code snippet
    dataset.dataset = iter([stage2])
    processed_sample = next(dataset)
    batch = dataset.collate_fn([processed_sample])

    # Obtain model prediction
    return model.forward_batch(batch)
Exemple #12
0
    def test_transformer_decoder(self):
        """Train fresh transformer decoders on one batch under five label scenarios.

        Every training step draws a new label for the first sample and re-randomizes its input
        tokens from position 2 onwards. The scenarios differ in two ways:

        - whether the label is also planted at position 5 of ``extended_vocabulary_ids``;
        - how the model is configured.

        Depending on the scenario, the model must either learn the label (via the pointer or
        via the decoder) or fail to learn it.
        """
        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                                 language='java-small',
                                                 partition='valid')
        word_vocab, _, _ = data_manager.load_vocabularies()

        dataset = CTCodeSummarizationDatasetNoPunctuation(
            data_manager, num_sub_tokens_output=6, use_pointer_network=True)
        dataloader = DataLoader(dataset,
                                batch_size=1,
                                collate_fn=dataset.collate_fn)

        iterator = iter(dataloader)
        batch = next(iterator)

        len_vocab = len(word_vocab)
        # Artificially set an out of vocabulary token
        batch.tokens[0][1][0] = word_vocab[UNKNOWN_TOKEN]

        def random_label():
            return torch.randint(0, len_vocab, (1, ))

        def train(model, make_label, plant_label_in_input, log_top1=False):
            """Train ``model`` for 100 steps on fresh labels; return the last step's output."""
            optimizer = optim.Adam(model.parameters(), lr=5e-3)
            out = None
            tq = tqdm(range(100))
            for _ in tq:
                secret_label = make_label()
                if plant_label_in_input:
                    # Plant the label at position 5 so the pointer network can copy it.
                    batch.extended_vocabulary_ids[0][5] = secret_label
                batch.labels[0] = torch.tensor([secret_label])
                batch.tokens[0][2:] = torch.randint(0, len_vocab,
                                                    (batch.tokens.shape[1] - 2, 5))
                out = model.forward_batch(batch)
                out.loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if log_top1:
                    tq.set_postfix(loss=out.loss.item(),
                                   top1=topk_accuracy(1, out.logits, batch.labels))
                else:
                    tq.set_postfix(loss=out.loss.item())
            return out

        def assert_label_copied_by_pointer(out):
            """The planted label must be predicted, with the pointer doing the work."""
            self.assertLess(out.loss.item(), 0.01)
            self.assertGreater(topk_accuracy(1, out.logits, batch.labels), 0.95)
            self.assertLess(out.pointer_gates.exp().max(), 0.05)
            # Pointer network should point to artificial position in input
            self.assertGreater(
                out.pointer_attentions[:, 0, batch.labels[0, 0,
                                                          0].item()].exp().min(),
                0.95)

        def assert_label_generated_by_decoder(out):
            """The label must be predicted, with the gates favouring the decoder."""
            self.assertLess(out.loss.item(), 0.5)
            self.assertGreater(topk_accuracy(1, out.logits, batch.labels), 0.95)
            self.assertGreater(out.pointer_gates.exp().min(), 0.95)

        # Pointer query self attention
        out = train(self._create_transformer_decoder_model(
            data_manager, use_query_self_attention=True),
            random_label, plant_label_in_input=True)
        assert_label_copied_by_pointer(out)

        # Deterministic in-vocabulary label that is NOT part of the input => model uses Decoder to generate label
        model = self._create_transformer_decoder_model(
            data_manager, use_pointer_network=True)
        out = train(model, lambda: len_vocab - 2, plant_label_in_input=False,
                    log_top1=True)
        assert_label_generated_by_decoder(out)
        # Re-check in evaluation mode on a fresh forward pass. Previously the eval output was
        # discarded and the stale training output was asserted on again.
        model.eval()
        with torch.no_grad():
            out = model.forward_batch(batch)
        assert_label_generated_by_decoder(out)

        # Random out-of-vocabulary label that is part of the input => model uses Pointer
        out = train(self._create_transformer_decoder_model(data_manager),
                    random_label, plant_label_in_input=True)
        assert_label_copied_by_pointer(out)

        # Random label that is NOT part of the input => no chance
        out = train(self._create_transformer_decoder_model(data_manager),
                    random_label, plant_label_in_input=False)
        self.assertGreater(out.loss.item(), 1)

        # Random in-vocabulary label that is part of the input, but no pointer network => no chance
        out = train(self._create_transformer_decoder_model(
            data_manager, use_pointer_network=False),
            random_label, plant_label_in_input=False)
        self.assertGreater(out.loss.item(), 1)
Exemple #13
0
    def load_model(self, run_id, snapshot_iteration, gpu=True):
        """Rebuild the Code Transformer decoder of a stored run and load a snapshot into it.

        Args:
            run_id: run whose parameters and config are read via ``self.load_parameters`` and
                ``self.load_config``.
            snapshot_iteration: which stored snapshot's parameters to load.
            gpu: forwarded unchanged to ``self.load_parameters``.

        Returns:
            A ``CodeTransformerDecoder`` holding the snapshot's state dict. When loading raises
            ``RuntimeError``, older checkpoint layouts are tried in a fixed order (see the
            inline comments). The last attempt's ``RuntimeError`` propagates if none match.
        """
        model_params = self.load_parameters(run_id, snapshot_iteration, gpu=gpu)
        config = self.load_config(run_id)
        model_config = self._prepare_model_config(config)
        data_setup = config['data_setup']

        language = data_setup['language']
        use_only_ast = data_setup.get('use_only_ast', False)
        data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2, language)

        decoder_config = model_config['lm_decoder']

        vocabularies = data_manager.load_vocabularies()
        if len(vocabularies) == 3:
            word_vocab, token_type_vocab, node_type_vocab = vocabularies
            method_name_vocab = None
        else:
            # A fourth vocabulary is the separate method-name (target) vocabulary.
            word_vocab, token_type_vocab, node_type_vocab, method_name_vocab = vocabularies

        encoder_config = model_config['lm_encoder']
        encoder_config['num_node_types'] = len(node_type_vocab)
        # Only-AST models are configured without token types.
        encoder_config['num_token_types'] = None if use_only_ast else len(token_type_vocab)
        encoder_config['vocab_size'] = len(word_vocab)
        encoder_config['transformer']['encoder_layer']['num_relative_distances'] = len(
            config['data_transforms']['relative_distances'])
        decoder_config['sos_id'] = word_vocab[SOS_TOKEN]
        decoder_config['output_subtokens_per_token'] = data_setup.get('num_subtokens_output',
                                                                      NUM_SUB_TOKENS)
        if 'use_pointer_network' in data_setup:
            decoder_config['use_pointer_network'] = data_setup['use_pointer_network']

        if ',' in data_manager.language:
            # Comma-separated language list => one language id per entry.
            encoder_config['num_languages'] = len(data_manager.language.split(','))

        decoder_config['lm_encoder'] = encoder_config
        decoder_config['loss_fct'] = model_config['loss_fct']

        if method_name_vocab is not None:
            decoder_config['target_vocab_size'] = len(method_name_vocab)

        def build_model():
            # Reads decoder_config at call time, so later tweaks below take effect.
            return CodeTransformerDecoder(TransformerLMDecoderConfig(**decoder_config))

        def load_succeeds(candidate):
            """Load the snapshot into ``candidate``; return False on a state-dict mismatch."""
            try:
                candidate.load_state_dict(model_params)
            except RuntimeError:
                return False
            return True

        model = build_model()
        if load_succeeds(model):
            return model

        # In most cases, this is due to the legacy issue with encoder_self_attention
        model.add_module('encoder_self_attention',
                         MultiheadAttention(model.d_model, decoder_config['decoder_nhead'],
                                            dropout=decoder_config['decoder_dropout']))
        if load_succeeds(model):
            return model

        # Retry without query/pointer concatenation.
        decoder_config['concat_query_and_pointer'] = False
        model = build_model()
        if load_succeeds(model):
            return model

        # Retry with concatenation but with the encoder's language embedding removed.
        decoder_config['concat_query_and_pointer'] = True
        model = build_model()
        model.lm_encoder.language_embedding = None
        if load_succeeds(model):
            return model

        # Last resort: no concatenation and the positional encoding replaced by an identity.
        decoder_config['concat_query_and_pointer'] = False
        model = build_model()

        class PositionalEncodingMock(nn.Module):
            """Identity stand-in: returns its input unchanged and ignores ``position``."""

            def forward(self, x, position):
                return x

        model.positional_encoding = PositionalEncodingMock()
        model.load_state_dict(model_params)
        return model