示例#1
0
    def test_only_positive(self):
        """Bin the non-negative integer distances 0..8 into 8 bins, 5 fixed.

        The first five distances keep their own bin; the larger ones are
        merged into the remaining bins.
        """
        distances = torch.arange(0, 9).to(torch.long)
        binning = DistanceBinning(8, n_fixed=5)
        indices, bin_values = binning(distances)

        expected = torch.tensor([0, 1, 2, 3, 4, 6, 6, 8, 8], dtype=torch.float)
        self.assertTrue(bin_values[indices].allclose(expected))
示例#2
0
    def test_continuous_distances(self):
        """Float-valued distances: the outer bins equal the min and max value."""
        distances = torch.tensor(
            [0.1, 1.2, 2.3, 4.5, 4.5, 5.6, 6.7, 7.8, 8.9])
        binning = DistanceBinning(8, n_fixed=5)
        _, bin_values = binning(distances)

        first_bin, last_bin = bin_values[0], bin_values[-1]
        self.assertEqual(first_bin, 0.1)
        self.assertEqual(last_bin, 8.9)
示例#3
0
 def test_without_unreachable(self):
     """Bin the symmetric integer distances -5..5 into 8 bins, 3 of them fixed."""
     distances = torch.arange(-5, 6).to(torch.long)
     indices, bin_values = DistanceBinning(8, n_fixed=3)(distances)

     # Index 0 is the UNREACHABLE bin; -1, 0 and 1 are the fixed bins.
     expected_bins = torch.tensor([UNREACHABLE, -5, -3.5, -1, 0, 1, 3.5, 5])
     assert bin_values.allclose(expected_bins)

     expected_binned = torch.tensor(
         [-5, -5, -3.5, -3.5, -1, 0, 1, 3.5, 3.5, 5, 5])
     assert bin_values[indices].allclose(expected_binned)
示例#4
0
    def test_2d_matrix(self):
        """Binning a 2-D matrix keeps its shape and bins every row alike.

        Each of the 7 rows holds the same distances -6..6, so every row of
        the binned result must equal the same expected vector.
        """
        values = torch.arange(-6, 7, step=1).unsqueeze(0).repeat((7, 1))
        n_bins = 10
        DB = DistanceBinning(n_bins, n_fixed=5)
        ixs, bins = DB(values.to(torch.long))
        binned = bins[ixs]

        # The bin indices must mirror the 2-D shape of the input.
        self.assertEqual(binned.shape, values.shape)

        expected_row = torch.tensor(
            [-6, -6, -4.5, -4.5, -2, -1, 0, 1, 2, 4.5, 4.5, 6, 6],
            dtype=torch.float)
        # Check all rows, not just the first one.
        self.assertTrue(binned.allclose(expected_row.expand_as(binned)))
示例#5
0
    def test_uneven_bins(self):
        """7 bins with 5 fixed: the single non-fixed bin ends up at value 10."""
        distances = torch.arange(-10, 11, step=1).to(torch.long)
        _, bin_values = DistanceBinning(7, n_fixed=5)(distances)

        expected = torch.tensor([UNREACHABLE, -2, -1, 0, 1, 2, 10],
                                dtype=torch.float)
        self.assertTrue(bin_values.allclose(expected))
示例#6
0
 def test_all_fixed_with_unreachable_alternative(self):
     """Every value in -50..50 gets a fixed bin; the outlier 1000 does not."""
     regular = torch.arange(-50, 51)
     outlier = torch.tensor([1000])
     distances = torch.cat([regular, outlier])
     binning = DistanceBinning(len(distances), n_fixed=len(distances) - 1)
     indices, bin_values = binning(distances.to(torch.long))

     unreachable_bin = torch.tensor([UNREACHABLE], dtype=torch.long)
     assert bin_values.to(torch.long).allclose(
         torch.cat([unreachable_bin, regular]))

     # The outlier is mapped onto the UNREACHABLE bin.
     expected_binned = torch.cat([regular, torch.tensor([UNREACHABLE])])
     assert bin_values[indices].allclose(expected_binned.to(torch.float32))
示例#7
0
 def test_all_fixed_even_number(self):
     """Every distance in -5..4 (an even count) gets its own fixed bin."""
     values = torch.arange(-5, 5)
     n_bins = len(values) + 1  # account for the UNREACHABLE bin
     DB = DistanceBinning(n_bins, n_fixed=len(values))
     ixs, bins = DB(values.to(torch.long))
     # bins should be the UNREACHABLE bin followed by one bin per value
     # (n_bins == 11 entries):
     # [UNREACHABLE, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
     assert bins.to(torch.long).allclose(
         torch.cat([torch.tensor([UNREACHABLE], dtype=torch.long), values]))
     # binned values should be:
     # [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
     assert bins[ixs].allclose(values.to(torch.float32))
示例#8
0
    def test_fewer_unique_values_than_bins(self):
        """Unused bins are filled with BIN_PADDING.

        Only 5 distinct distances (-10, -5, 0, 5, 10) exist but 32 bins are
        requested: after the UNREACHABLE bin and the 5 occupied bins, the
        remaining 26 entries are padding.
        """
        distances = torch.arange(-10, 11, step=5).unsqueeze(0).repeat((3, 1))
        binning = DistanceBinning(32, n_fixed=5)
        _, bin_values = binning(distances.to(torch.long))

        occupied = torch.tensor([UNREACHABLE, -10, -5, 0, 5, 10],
                                dtype=torch.float)
        padding = torch.tensor([BIN_PADDING], dtype=torch.float).expand(26)
        self.assertTrue(bin_values.allclose(torch.cat([occupied, padding])))
示例#9
0
    def test_mixed_positive_negative(self):
        """Unordered positive/negative distances with 3 fixed bins.

        Values outside the fixed range [-1, 1] collapse onto the outer bins
        at -8 and 8.
        """
        distances = torch.tensor([5, -4, 3, -2, 1, 0, 8, 8, -8, -8])
        binning = DistanceBinning(6, n_fixed=3)
        indices, bin_values = binning(distances.to(torch.long))

        expected_bins = torch.tensor([UNREACHABLE, -8, -1, 0, 1, 8],
                                     dtype=torch.float)
        self.assertTrue(bin_values.allclose(expected_bins))

        expected_binned = torch.tensor([8, -8, 8, -8, 1, 0, 8, 8, -8, -8],
                                       dtype=torch.float)
        self.assertTrue(bin_values[indices].allclose(expected_binned))
示例#10
0
 def _setup_distances_transformer(self):
     """Create ``self.distances_transformer`` from the configured attributes.

     Builds the four distance metrics (personalized page rank, shortest
     paths, ancestor shortest paths, sibling shortest paths) from the
     corresponding ``self`` attributes and bins their values with a
     ``DistanceBinning``. That binning uses ``ExponentialBinning`` when
     ``self.exponential_binning`` is truthy, and its default otherwise.
     """
     ppr = PersonalizedPageRank(threshold=self.ppr_threshold,
                                log=self.ppr_use_log,
                                alpha=self.ppr_alpha)
     shortest_paths = ShortestPaths(threshold=self.sp_threshold)
     ancestor_sp = AncestorShortestPaths(
         forward=self.ancestor_sp_forward,
         backward=self.ancestor_sp_backward,
         negative_reverse_dists=self.ancestor_sp_negative_reverse_dists,
         threshold=self.ancestor_sp_threshold)
     sibling_sp = SiblingShortestPaths(
         forward=self.sibling_sp_forward,
         backward=self.sibling_sp_backward,
         negative_reverse_dists=self.sibling_sp_negative_reverse_dists,
         threshold=self.sibling_sp_threshold)
     metrics = [ppr, shortest_paths, ancestor_sp, sibling_sp]

     binning_args = [self.num_bins, self.n_fixed_bins]
     if self.exponential_binning:
         binning_args.append(
             ExponentialBinning(self.exponential_binning_growth_factor))
     self.distances_transformer = DistancesTransformer(
         metrics, DistanceBinning(*binning_args))
示例#11
0
 def test_all_fixed_with_unreachable(self):
     """Fixed bins for -5..5; both outliers (1000 and -1000) become UNREACHABLE."""
     regular = torch.arange(-5, 6)
     outliers = [torch.tensor([1000]), torch.tensor([-1000])]
     distances = torch.cat([regular] + outliers)
     binning = DistanceBinning(len(distances) - 1,
                               n_fixed=len(distances) - 2)
     indices, bin_values = binning(distances.to(torch.long))

     # Expected bins: [UNREACHABLE, -5, -4, ..., 4, 5]
     unreachable_bin = torch.tensor([UNREACHABLE], dtype=torch.long)
     assert bin_values.to(torch.long).allclose(
         torch.cat([unreachable_bin, regular]))

     # Expected binned values: -5..5 followed by UNREACHABLE twice.
     expected = torch.cat([regular,
                           torch.tensor([UNREACHABLE]),
                           torch.tensor([UNREACHABLE])])
     assert bin_values[indices].allclose(expected.to(torch.float32))
示例#12
0
def predict_method_name(model,
                        model_config,
                        code_snippet: str,
                        method_name_place_holder='f'):
    """Run ``model`` on a single raw code snippet and return its output.

    The snippet goes through the same pipeline as the stored training data:
    stage-1 preprocessing (AST), stage-2 preprocessing (vocabulary lookup
    and distance matrices, configured from the stored data config), and
    finally a code-summarization dataset that turns it into a batch.

    Args:
        model: Model exposing ``forward_batch(batch)``.
        model_config: Model configuration dict; its ``data_setup`` and
            ``data_transforms`` sections are read.
        code_snippet: Source code of the method to process.
        method_name_place_holder: Placeholder method name passed to the
            stage-1 preprocessor together with the snippet.

    Returns:
        Whatever ``model.forward_batch`` returns for the one-sample batch.
    """
    data_setup = model_config['data_setup']
    language = data_setup['language']

    # Build data manager and load vocabularies + configs
    data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                             language,
                                             partition='train',
                                             shuffle=True)
    vocabs = data_manager.load_vocabularies()
    data_config = data_manager.load_config()

    # Stage 1 Preprocessing (Compute AST)
    # All Java dataset variants are lexed as plain 'java'.
    lexer_language = 'java' if language in {
        'java-small', 'java-medium', 'java-large', 'java-small-pretrain',
        'java-pretrain'
    } else language
    preprocessor = CTStage1Preprocessor(lexer_language,
                                        allow_empty_methods=True)
    stage1 = preprocessor.process(
        [(method_name_place_holder, "", code_snippet)], 0)

    # Stage 2 Preprocessing (Compute Distance Matrices)
    distances_config = data_config['distances']
    PPR_ALPHA = distances_config['ppr_alpha']
    PPR_USE_LOG = distances_config['ppr_use_log']
    PPR_THRESHOLD = distances_config['ppr_threshold']

    SP_THRESHOLD = distances_config['sp_threshold']

    ANCESTOR_SP_FORWARD = distances_config['ancestor_sp_forward']
    ANCESTOR_SP_BACKWARD = distances_config['ancestor_sp_backward']
    ANCESTOR_SP_NEGATIVE_REVERSE_DISTS = distances_config[
        'ancestor_sp_negative_reverse_dists']
    ANCESTOR_SP_THRESHOLD = distances_config['ancestor_sp_threshold']

    SIBLING_SP_FORWARD = distances_config['sibling_sp_forward']
    SIBLING_SP_BACKWARD = distances_config['sibling_sp_backward']
    SIBLING_SP_NEGATIVE_REVERSE_DISTS = distances_config[
        'sibling_sp_negative_reverse_dists']
    SIBLING_SP_THRESHOLD = distances_config['sibling_sp_threshold']

    binning_config = data_config['binning']
    EXPONENTIAL_BINNING_GROWTH_FACTOR = binning_config[
        'exponential_binning_growth_factor']
    N_FIXED_BINS = binning_config['n_fixed_bins']
    # How many bins should be calculated for the values in distance matrices
    NUM_BINS = binning_config['num_bins']

    preprocessing_config = data_config['preprocessing']
    REMOVE_PUNCTUATION = preprocessing_config['remove_punctuation']

    distance_metrics = [
        PersonalizedPageRank(threshold=PPR_THRESHOLD,
                             log=PPR_USE_LOG,
                             alpha=PPR_ALPHA),
        ShortestPaths(threshold=SP_THRESHOLD),
        AncestorShortestPaths(
            forward=ANCESTOR_SP_FORWARD,
            backward=ANCESTOR_SP_BACKWARD,
            negative_reverse_dists=ANCESTOR_SP_NEGATIVE_REVERSE_DISTS,
            threshold=ANCESTOR_SP_THRESHOLD),
        SiblingShortestPaths(
            forward=SIBLING_SP_FORWARD,
            backward=SIBLING_SP_BACKWARD,
            negative_reverse_dists=SIBLING_SP_NEGATIVE_REVERSE_DISTS,
            threshold=SIBLING_SP_THRESHOLD)
    ]

    db = DistanceBinning(NUM_BINS, N_FIXED_BINS,
                         ExponentialBinning(EXPONENTIAL_BINNING_GROWTH_FACTOR))

    distances_transformer = DistancesTransformer(distance_metrics, db)
    # Four vocabularies means the data has a separate method-name vocabulary.
    if len(vocabs) == 4:
        vocabulary_transformer = CodeSummarizationVocabularyTransformer(
            *vocabs)
    else:
        vocabulary_transformer = VocabularyTransformer(*vocabs)

    stage2 = stage1[0]
    if REMOVE_PUNCTUATION:
        stage2.remove_punctuation()
    stage2 = vocabulary_transformer(stage2)
    stage2 = distances_transformer(stage2)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    data_transforms = model_config['data_transforms']
    token_distances = None
    if TokenDistancesTransform.name in data_transforms['relative_distances']:
        # ``num_bins`` is stored in the ``binning`` section of the data
        # config (read into NUM_BINS above); reuse it instead of reloading
        # the config and looking it up at the top level.
        distance_binning_config = data_transforms['distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(NUM_BINS, distance_binning_config['n_fixed_bins'],
                            trans_func))

    # A missing flag is treated as False, matching make_batch_from_sample.
    if data_setup.get('use_no_punctuation', False):
        dataset_cls = CTCodeSummarizationDatasetNoPunctuation
    else:
        dataset_cls = CTCodeSummarizationDataset
    dataset = dataset_cls(
        data_manager,
        num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
        use_pointer_network=data_setup['use_pointer_network'],
        max_num_tokens=LIMIT_TOKENS,
        token_distances=token_distances)

    # Hijack dataset to only contain user specified code snippet
    dataset.dataset = (stage2 for _ in range(1))
    processed_sample = next(dataset)
    batch = dataset.collate_fn([processed_sample])

    # Obtain model prediction
    output = model.forward_batch(batch)

    return output
    def _init_data(self,
                   language,
                   use_validation=False,
                   mini_dataset=False,
                   num_sub_tokens=5,
                   num_subtokens_output=5,
                   use_only_ast=False,
                   sort_by_length=False,
                   shuffle=True,
                   use_pointer_network=False):
        """Set up the edge-type code-summarization datasets.

        Creates ``self.data_manager`` over ``DATA_PATH_STAGE_2`` and loads
        its three vocabularies. It then builds ``self.dataset_train`` and,
        when ``use_validation`` is set, ``self.dataset_validation``, both as
        ``CTCodeSummarizationDatasetEdgeTypes`` with the same settings.
        Finally it stores ``self.dataset_validation_creator``, a callable
        that builds a fresh validation dataset on demand.

        Note: ``use_only_ast`` is accepted but not used by this method.
        """
        self.num_sub_tokens = num_sub_tokens
        self.use_validation = use_validation
        self.use_pointer_network = use_pointer_network
        self.data_manager = CTBufferedDataManager(
            DATA_PATH_STAGE_2,
            language,
            shuffle=shuffle,
            infinite_loading=True,
            mini_dataset=mini_dataset,
            sort_by_length=sort_by_length)
        (self.word_vocab,
         self.token_type_vocab,
         self.node_type_vocab) = self.data_manager.load_vocabularies()

        if ',' in language:
            # Count the comma-separated languages.
            self.num_languages = len(language.split(','))

        self.use_separate_vocab = False

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            num_bins = self.data_manager.load_config()['binning']['num_bins']
            token_distances = TokenDistancesTransform(
                DistanceBinning(num_bins,
                                self.distance_binning['n_fixed_bins'],
                                self.distance_binning['trans_func']))

        def build_dataset(manager):
            # Training and validation datasets share every setting.
            return CTCodeSummarizationDatasetEdgeTypes(
                manager,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)

        self.dataset_train = build_dataset(self.data_manager)

        if self.use_validation:
            validation_manager = CTBufferedDataManager(
                DATA_PATH_STAGE_2,
                language,
                partition="valid",
                shuffle=True,
                infinite_loading=True,
                mini_dataset=mini_dataset)
            self.dataset_validation = build_dataset(validation_manager)

        def create_validation_dataset(infinite_loading):
            return self._create_validation_dataset(
                DATA_PATH_STAGE_2, language, token_distances,
                num_sub_tokens, num_subtokens_output,
                infinite_loading=infinite_loading,
                use_pointer_network=use_pointer_network,
                filter_language=None, dataset_imbalance=None)

        self.dataset_validation_creator = create_validation_dataset
示例#14
0
    def _init_data(self,
                   language,
                   use_validation=False,
                   mini_dataset=False,
                   num_sub_tokens=NUM_SUB_TOKENS,
                   num_subtokens_output=NUM_SUB_TOKENS_METHOD_NAME,
                   use_only_ast=False,
                   use_no_punctuation=False,
                   use_pointer_network=False,
                   sort_by_length=False,
                   chunk_size=None,
                   filter_language=None,
                   dataset_imbalance=None,
                   mask_all_tokens=False):
        """Set up the training (and optional validation) summarization datasets.

        The dataset class is chosen once from the flags:
        - ``use_only_ast`` selects ``CTCodeSummarizationOnlyASTDataset``,
          which also receives ``mask_all_tokens``;
        - otherwise ``use_no_punctuation`` selects
          ``CTCodeSummarizationDatasetNoPunctuation``;
        - otherwise ``CTCodeSummarizationDataset`` is used.

        Training and validation datasets use the same class and arguments.
        ``self.dataset_validation_creator`` is stored for building fresh
        validation datasets later.
        """
        self.data_manager = CTBufferedDataManager(
            DATA_PATH_STAGE_2,
            language,
            shuffle=True,
            infinite_loading=True,
            mini_dataset=mini_dataset,
            sort_by_length=sort_by_length,
            chunk_size=chunk_size,
            filter_language=filter_language,
            dataset_imbalance=dataset_imbalance)
        vocabs = self.data_manager.load_vocabularies()
        # A fourth vocabulary means method names have their own vocabulary.
        has_method_name_vocab = len(vocabs) == 4
        if has_method_name_vocab:
            (self.word_vocab, self.token_type_vocab, self.node_type_vocab,
             self.method_name_vocab) = vocabs
        else:
            self.word_vocab, self.token_type_vocab, self.node_type_vocab = vocabs
        self.use_separate_vocab = has_method_name_vocab

        if ',' in language:
            # Count the comma-separated languages.
            self.num_languages = len(language.split(','))

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            print('Token distances will be calculated in dataset.')
            binning_config = self.data_manager.load_config()['binning']
            token_distances = TokenDistancesTransform(
                DistanceBinning(binning_config['num_bins'],
                                self.distance_binning['n_fixed_bins'],
                                self.distance_binning['trans_func']))

        self.use_only_ast = use_only_ast
        self.use_pointer_network = use_pointer_network
        self.num_sub_tokens = num_sub_tokens

        if use_only_ast:
            dataset_cls = CTCodeSummarizationOnlyASTDataset
            extra_kwargs = {'mask_all_tokens': mask_all_tokens}
        elif use_no_punctuation:
            dataset_cls = CTCodeSummarizationDatasetNoPunctuation
            extra_kwargs = {}
        else:
            dataset_cls = CTCodeSummarizationDataset
            extra_kwargs = {}

        def build_dataset(manager):
            # Training and validation datasets share every setting.
            return dataset_cls(manager,
                               token_distances=token_distances,
                               max_distance_mask=self.max_distance_mask,
                               num_sub_tokens=num_sub_tokens,
                               num_sub_tokens_output=num_subtokens_output,
                               use_pointer_network=use_pointer_network,
                               **extra_kwargs)

        self.dataset_train = build_dataset(self.data_manager)

        self.use_validation = use_validation
        if self.use_validation:
            validation_manager = CTBufferedDataManager(
                DATA_PATH_STAGE_2,
                language,
                partition="valid",
                shuffle=True,
                infinite_loading=True,
                mini_dataset=mini_dataset,
                filter_language=filter_language,
                dataset_imbalance=dataset_imbalance)
            self.dataset_validation = build_dataset(validation_manager)

        def create_validation_dataset(infinite_loading):
            return self._create_validation_dataset(DATA_PATH_STAGE_2,
                                                   language,
                                                   use_only_ast,
                                                   use_no_punctuation,
                                                   token_distances,
                                                   num_sub_tokens,
                                                   num_subtokens_output,
                                                   infinite_loading,
                                                   use_pointer_network,
                                                   filter_language,
                                                   dataset_imbalance,
                                                   mask_all_tokens)

        self.dataset_validation_creator = create_validation_dataset
示例#15
0
def make_batch_from_sample(stage2_sample: CTStage2Sample, model_config,
                           model_type):
    """Turn a single stage-2 sample into a collated model input batch.

    The dataset flavour is chosen from ``model_type`` and
    ``model_config['data_setup']``:
    - ``'great'`` models use ``CTCodeSummarizationDatasetEdgeTypes``;
    - otherwise ``use_only_ast`` selects the AST-only dataset;
    - otherwise ``use_no_punctuation`` selects the no-punctuation dataset;
    - otherwise the regular summarization dataset is used.

    The dataset is then fed only this sample.

    Args:
        stage2_sample: Preprocessed sample; its ``token_mapping`` must be a
            dict.
        model_config: Model configuration dict (``data_setup`` and
            ``data_transforms`` sections are read).
        model_type: Model type name; ``'great'`` selects the edge-type
            dataset.

    Returns:
        The batch produced by the dataset's ``collate_fn`` for this sample.

    Raises:
        AssertionError: if ``stage2_sample.token_mapping`` is not a dict
            (the message asks to re-generate the sample).
    """
    # Explicit raise so the check also runs under ``python -O``.
    if not isinstance(stage2_sample.token_mapping, dict):
        raise AssertionError("Please re-generate the sample")

    data_setup = model_config['data_setup']
    data_manager = CTPreprocessedDataManager(
        DATA_PATH_STAGE_2,
        data_setup['language'],
        partition='train',
        shuffle=True)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    data_transforms = model_config['data_transforms']
    token_distances = None
    if TokenDistancesTransform.name in data_transforms['relative_distances']:
        data_config = data_manager.load_config()
        # The stage-2 data config stores ``num_bins`` in its ``binning``
        # section; fall back to a top-level key for configs without one.
        if 'binning' in data_config:
            num_bins = data_config['binning']['num_bins']
        else:
            num_bins = data_config['num_bins']
        distance_binning_config = data_transforms['distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'],
                            trans_func))

    if model_type == 'great':
        dataset_cls = CTCodeSummarizationDatasetEdgeTypes
    elif data_setup.get('use_only_ast'):
        dataset_cls = CTCodeSummarizationOnlyASTDataset
    elif data_setup.get('use_no_punctuation'):
        dataset_cls = CTCodeSummarizationDatasetNoPunctuation
    else:
        dataset_cls = CTCodeSummarizationDataset

    dataset = dataset_cls(
        data_manager,
        num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
        use_pointer_network=data_setup['use_pointer_network'],
        max_num_tokens=LIMIT_TOKENS,
        token_distances=token_distances)

    # Hijack dataset to only contain user specified code snippet
    dataset.dataset = (stage2_sample for _ in range(1))
    processed_sample = next(dataset)
    batch = dataset.collate_fn([processed_sample])

    return batch
示例#16
0
    # load_vocabularies() returns either 3 or 4 vocabularies; only the first
    # one (the word vocabulary) is kept here.
    vocabularies = data_manager.load_vocabularies()
    if len(vocabularies) == 3:
        word_vocab, _, _ = vocabularies
    else:
        word_vocab, _, _, _ = vocabularies

    # Build the token-distance transform only when the model config lists
    # token distances among its relative distances. The bin count comes from
    # the data config's 'binning' section; the transform function from the
    # model's distance_binning settings.
    token_distances = None
    if TokenDistancesTransform.name in config['data_transforms']['relative_distances']:
        num_bins = data_manager.load_config()['binning']['num_bins']
        distance_binning_config = config['data_transforms']['distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'], trans_func))

    # Select the dataset flavour: the 'great' model gets its own type;
    # otherwise the data_setup flags choose AST-only, then no-punctuation,
    # falling back to the regular dataset. Missing flags count as False.
    use_pointer_network = config['data_setup']['use_pointer_network']
    if args.model in {'great'}:
        dataset_type = 'great'
    elif 'use_only_ast' in config['data_setup'] and config['data_setup']['use_only_ast']:
        dataset_type = 'only_ast'
    elif 'use_no_punctuation' in config['data_setup'] and config['data_setup']['use_no_punctuation']:
        dataset_type = 'no_punctuation'
    else:
        dataset_type = 'regular'

    # Report what is about to be evaluated.
    print(
        f"Evaluating model snapshot-{args.snapshot_iteration} from run {args.run_id} on {config['data_setup']['language']} partition {args.partition}")
    print(f"gpu: {not args.no_gpu}")
    print(f"dataset_type: {dataset_type}")
示例#17
0
    def _init_data(self, language, num_predict, use_validation=False, mini_dataset=False,
                   use_no_punctuation=False, use_pointer_network=False, sort_by_length=False, shuffle=True,
                   chunk_size=None, filter_language=None, dataset_imbalance=None, num_sub_tokens=NUM_SUB_TOKENS):
        """Set up the language-modeling datasets.

        Creates ``self.data_manager`` over ``DATA_PATH_STAGE_2`` and loads
        the method-body vocabularies. It builds ``self.dataset_train`` and,
        when ``use_validation`` is set, ``self.dataset_validation``. The
        dataset class is ``CTLanguageModelingDatasetNoPunctuation`` when
        ``use_no_punctuation`` is set, otherwise ``CTLanguageModelingDataset``.
        Finally it stores ``self.dataset_validation_creator`` for building
        fresh validation datasets on demand.

        Args:
            language: Dataset language; comma-separated for several languages.
            num_predict: Number of labels per sample.
            num_sub_tokens: Number of sub-tokens per token passed to the
                datasets.
            (remaining flags are forwarded to the data managers/datasets)
        """
        self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2, language, shuffle=shuffle,
                                                  infinite_loading=True,
                                                  mini_dataset=mini_dataset, size_load_buffer=10000,
                                                  sort_by_length=sort_by_length, chunk_size=chunk_size,
                                                  filter_language=filter_language, dataset_imbalance=dataset_imbalance)
        # Data preprocessed with a separate method-name vocabulary returns it
        # as a 4th entry. Language modeling never uses it, so keep only the
        # first three instead of failing on a 3-way unpacking.
        vocabs = self.data_manager.load_vocabularies()
        self.word_vocab, self.token_type_vocab, self.node_type_vocab = vocabs[:3]

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            num_bins = self.data_manager.load_config()['binning']['num_bins']
            token_distances = TokenDistancesTransform(
                DistanceBinning(num_bins, self.distance_binning['n_fixed_bins'], self.distance_binning['trans_func']))

        self.num_predict = num_predict
        self.use_pointer_network = use_pointer_network
        self.use_separate_vocab = False  # For language modeling we always only operate on the method body vocabulary

        if use_no_punctuation:
            dataset_cls = CTLanguageModelingDatasetNoPunctuation
        else:
            dataset_cls = CTLanguageModelingDataset

        def build_dataset(manager):
            # Training and validation datasets share every setting.
            return dataset_cls(manager,
                               token_distances=token_distances,
                               max_distance_mask=self.max_distance_mask,
                               num_labels_per_sample=num_predict,
                               use_pointer_network=use_pointer_network,
                               num_sub_tokens=num_sub_tokens)

        self.dataset_train = build_dataset(self.data_manager)

        self.use_validation = use_validation
        if self.use_validation:
            data_manager_validation = CTBufferedDataManager(DATA_PATH_STAGE_2, language, partition="valid",
                                                            shuffle=True, infinite_loading=True,
                                                            mini_dataset=mini_dataset, size_load_buffer=10000,
                                                            filter_language=filter_language,
                                                            dataset_imbalance=dataset_imbalance)
            self.dataset_validation = build_dataset(data_manager_validation)

        self.dataset_validation_creator = \
            lambda infinite_loading: self._create_validation_dataset(DATA_PATH_STAGE_2,
                                                                     language,
                                                                     use_no_punctuation,
                                                                     token_distances,
                                                                     infinite_loading,
                                                                     num_predict,
                                                                     use_pointer_network,
                                                                     filter_language,
                                                                     dataset_imbalance,
                                                                     num_sub_tokens)