def test_only_positive(self):
    values = torch.arange(0, 9)
    n_bins = 8
    DB = DistanceBinning(n_bins, n_fixed=5)
    ixs, bins = DB(values.to(torch.long))
    self.assertTrue(bins[ixs].allclose(
        torch.tensor([0, 1, 2, 3, 4, 6, 6, 8, 8], dtype=torch.float)))
def test_continuous_distances(self):
    values = torch.tensor([0.1, 1.2, 2.3, 4.5, 4.5, 5.6, 6.7, 7.8, 8.9])
    n_bins = 8
    DB = DistanceBinning(n_bins, n_fixed=5)
    ixs, bins = DB(values)
    self.assertEqual(bins[0], 0.1)
    self.assertEqual(bins[-1], 8.9)
def test_without_unreachable(self):
    values = torch.arange(-5, 6)
    n_bins = 8
    DB = DistanceBinning(n_bins, n_fixed=3)
    ixs, bins = DB(values.to(torch.long))
    assert bins.allclose(
        torch.tensor([UNREACHABLE, -5, -3.5, -1, 0, 1, 3.5, 5]))
    assert bins[ixs].allclose(
        torch.tensor([-5, -5, -3.5, -3.5, -1, 0, 1, 3.5, 3.5, 5, 5]))
def test_2d_matrix(self):
    values = torch.arange(-6, 7, step=1).unsqueeze(0).repeat((7, 1))
    n_bins = 10
    DB = DistanceBinning(n_bins, n_fixed=5)
    ixs, bins = DB(values.to(torch.long))
    self.assertTrue(bins[ixs][0].allclose(
        torch.tensor([-6, -6, -4.5, -4.5, -2, -1, 0, 1, 2, 4.5, 4.5, 6, 6],
                     dtype=torch.float)))
def test_uneven_bins(self):
    values = torch.arange(-10, 11, step=1)
    n_bins = 7
    DB = DistanceBinning(n_bins, n_fixed=5)
    ixs, bins = DB(values.to(torch.long))
    self.assertTrue(
        bins.allclose(
            torch.tensor([UNREACHABLE, -2, -1, 0, 1, 2, 10],
                         dtype=torch.float)))
def test_all_fixed_with_unreachable_alternative(self):
    values_orig = torch.arange(-50, 51)
    values = torch.cat([values_orig, torch.tensor([1000])])
    n_bins = len(values)
    DB = DistanceBinning(n_bins, n_fixed=len(values) - 1)
    ixs, bins = DB(values.to(torch.long))
    assert bins.to(torch.long).allclose(
        torch.cat(
            [torch.tensor([UNREACHABLE], dtype=torch.long), values_orig]))
    assert bins[ixs].allclose(
        torch.cat([values_orig,
                   torch.tensor([UNREACHABLE])]).to(torch.float32))
def test_all_fixed_even_number(self):
    values = torch.arange(-5, 5)
    n_bins = len(values) + 1  # account for the UNREACHABLE bin
    DB = DistanceBinning(n_bins, n_fixed=len(values))
    ixs, bins = DB(values.to(torch.long))
    # bins should be:
    # [UNREACHABLE, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    assert bins.to(torch.long).allclose(
        torch.cat([torch.tensor([UNREACHABLE], dtype=torch.long), values]))
    # binned values should be:
    # [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
    assert bins[ixs].allclose(values.to(torch.float32))
def test_fewer_unique_values_than_bins(self):
    values = torch.arange(-10, 11, step=5).unsqueeze(0).repeat((3, 1))
    n_bins = 32
    DB = DistanceBinning(n_bins, n_fixed=5)
    ixs, bins = DB(values.to(torch.long))
    self.assertTrue(
        bins.allclose(
            torch.cat([
                torch.tensor([UNREACHABLE, -10, -5, 0, 5, 10],
                             dtype=torch.float),
                torch.tensor([BIN_PADDING], dtype=torch.float).expand(26)
            ])))
def test_mixed_positive_negative(self):
    values = torch.tensor([5, -4, 3, -2, 1, 0, 8, 8, -8, -8])
    n_bins = 6
    DB = DistanceBinning(n_bins, n_fixed=3)
    ixs, bins = DB(values.to(torch.long))
    self.assertTrue(
        bins.allclose(
            torch.tensor([UNREACHABLE, -8, -1, 0, 1, 8], dtype=torch.float)))
    self.assertTrue(bins[ixs].allclose(
        torch.tensor([8, -8, 8, -8, 1, 0, 8, 8, -8, -8], dtype=torch.float)))
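# Minimal sketch of the contract the tests above exercise (illustrative only; it
# assumes DistanceBinning is importable from this project's preprocessing package --
# the exact import path is not shown in this file). A DistanceBinning instance maps a
# tensor of raw distance values to bin indices `ixs` plus the representative value of
# each bin in `bins` (in the tests above, an UNREACHABLE sentinel occupies the first
# bin whenever one is needed), so that `bins[ixs]` is the quantized input with the
# same shape as `values`.
values = torch.arange(-10, 11)
ixs, bins = DistanceBinning(8, n_fixed=5)(values.to(torch.long))
quantized = bins[ixs]  # quantized distances, same shape as the input
assert quantized.shape == values.shape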
def _setup_distances_transformer(self):
    distance_metrics = [
        PersonalizedPageRank(threshold=self.ppr_threshold,
                             log=self.ppr_use_log,
                             alpha=self.ppr_alpha),
        ShortestPaths(threshold=self.sp_threshold),
        AncestorShortestPaths(
            forward=self.ancestor_sp_forward,
            backward=self.ancestor_sp_backward,
            negative_reverse_dists=self.ancestor_sp_negative_reverse_dists,
            threshold=self.ancestor_sp_threshold),
        SiblingShortestPaths(
            forward=self.sibling_sp_forward,
            backward=self.sibling_sp_backward,
            negative_reverse_dists=self.sibling_sp_negative_reverse_dists,
            threshold=self.sibling_sp_threshold)
    ]
    if self.exponential_binning:
        db = DistanceBinning(
            self.num_bins, self.n_fixed_bins,
            ExponentialBinning(self.exponential_binning_growth_factor))
    else:
        db = DistanceBinning(self.num_bins, self.n_fixed_bins)
    self.distances_transformer = DistancesTransformer(distance_metrics, db)
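# Illustrative sketch only: one plausible way "exponential binning" with a growth
# factor can lay out bin edges (fine resolution for small distances, coarse for large
# ones). This is NOT the ExponentialBinning implementation used above; the real class
# may compute its bins differently. The helper name below is hypothetical.
import torch

def exponential_bin_edges(max_dist, n_bins, growth_factor):
    # Bin widths grow geometrically (w, w*g, w*g^2, ...) and are rescaled so that
    # together they exactly cover the range [0, max_dist].
    widths = growth_factor ** torch.arange(n_bins, dtype=torch.float)
    widths = widths / widths.sum() * max_dist
    return torch.cat([torch.zeros(1), widths.cumsum(dim=0)])

# e.g. exponential_bin_edges(100, 8, 1.5) gives narrow bins near 0, wide bins near 100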
def test_all_fixed_with_unreachable(self):
    values_orig = torch.arange(-5, 6)
    values = torch.cat(
        [values_orig, torch.tensor([1000]), torch.tensor([-1000])])
    n_bins = len(values) - 1
    DB = DistanceBinning(n_bins, n_fixed=len(values) - 2)
    ixs, bins = DB(values.to(torch.long))
    # bins should be:
    # [UNREACHABLE, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
    assert bins.to(torch.long).allclose(
        torch.cat(
            [torch.tensor([UNREACHABLE], dtype=torch.long), values_orig]))
    # binned values should be:
    # [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, UNREACHABLE, UNREACHABLE]
    assert bins[ixs].allclose(
        torch.cat([
            values_orig,
            torch.tensor([UNREACHABLE]),
            torch.tensor([UNREACHABLE])
        ]).to(torch.float32))
def predict_method_name(model,
                        model_config,
                        code_snippet: str,
                        method_name_place_holder='f'):
    language = model_config['data_setup']['language']

    # Build data manager and load vocabularies + configs
    data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2,
                                             language,
                                             partition='train',
                                             shuffle=True)
    vocabs = data_manager.load_vocabularies()
    if len(vocabs) == 4:
        method_name_vocab = vocabs[-1]
    else:
        method_name_vocab = vocabs[0]
    word_vocab = vocabs[0]
    data_config = data_manager.load_config()

    # Stage 1 preprocessing (compute AST)
    lexer_language = 'java' if language in {
        'java-small', 'java-medium', 'java-large', 'java-small-pretrain',
        'java-pretrain'
    } else language
    preprocessor = CTStage1Preprocessor(lexer_language,
                                        allow_empty_methods=True)
    stage1 = preprocessor.process(
        [(method_name_place_holder, "", code_snippet)], 0)

    # Stage 2 preprocessing (compute distance matrices)
    distances_config = data_config['distances']
    PPR_ALPHA = distances_config['ppr_alpha']
    PPR_USE_LOG = distances_config['ppr_use_log']
    PPR_THRESHOLD = distances_config['ppr_threshold']
    SP_THRESHOLD = distances_config['sp_threshold']
    ANCESTOR_SP_FORWARD = distances_config['ancestor_sp_forward']
    ANCESTOR_SP_BACKWARD = distances_config['ancestor_sp_backward']
    ANCESTOR_SP_NEGATIVE_REVERSE_DISTS = distances_config[
        'ancestor_sp_negative_reverse_dists']
    ANCESTOR_SP_THRESHOLD = distances_config['ancestor_sp_threshold']
    SIBLING_SP_FORWARD = distances_config['sibling_sp_forward']
    SIBLING_SP_BACKWARD = distances_config['sibling_sp_backward']
    SIBLING_SP_NEGATIVE_REVERSE_DISTS = distances_config[
        'sibling_sp_negative_reverse_dists']
    SIBLING_SP_THRESHOLD = distances_config['sibling_sp_threshold']

    binning_config = data_config['binning']
    EXPONENTIAL_BINNING_GROWTH_FACTOR = binning_config[
        'exponential_binning_growth_factor']
    N_FIXED_BINS = binning_config['n_fixed_bins']
    # How many bins should be calculated for the values in distance matrices
    NUM_BINS = binning_config['num_bins']

    preprocessing_config = data_config['preprocessing']
    REMOVE_PUNCTUATION = preprocessing_config['remove_punctuation']

    distance_metrics = [
        PersonalizedPageRank(threshold=PPR_THRESHOLD,
                             log=PPR_USE_LOG,
                             alpha=PPR_ALPHA),
        ShortestPaths(threshold=SP_THRESHOLD),
        AncestorShortestPaths(
            forward=ANCESTOR_SP_FORWARD,
            backward=ANCESTOR_SP_BACKWARD,
            negative_reverse_dists=ANCESTOR_SP_NEGATIVE_REVERSE_DISTS,
            threshold=ANCESTOR_SP_THRESHOLD),
        SiblingShortestPaths(
            forward=SIBLING_SP_FORWARD,
            backward=SIBLING_SP_BACKWARD,
            negative_reverse_dists=SIBLING_SP_NEGATIVE_REVERSE_DISTS,
            threshold=SIBLING_SP_THRESHOLD)
    ]
    db = DistanceBinning(NUM_BINS, N_FIXED_BINS,
                         ExponentialBinning(EXPONENTIAL_BINNING_GROWTH_FACTOR))

    distances_transformer = DistancesTransformer(distance_metrics, db)
    if len(vocabs) == 4:
        vocabulary_transformer = CodeSummarizationVocabularyTransformer(
            *vocabs)
    else:
        vocabulary_transformer = VocabularyTransformer(*vocabs)

    stage2 = stage1[0]
    if REMOVE_PUNCTUATION:
        stage2.remove_punctuation()
    stage2 = vocabulary_transformer(stage2)
    stage2 = distances_transformer(stage2)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    token_distances = None
    if TokenDistancesTransform.name in model_config['data_transforms'][
            'relative_distances']:
        num_bins = data_manager.load_config()['binning']['num_bins']
        distance_binning_config = model_config['data_transforms'][
            'distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins,
                            distance_binning_config['n_fixed_bins'],
                            trans_func))

    if model_config['data_setup']['use_no_punctuation']:
        dataset = CTCodeSummarizationDatasetNoPunctuation(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=model_config['data_setup']
            ['use_pointer_network'],
            max_num_tokens=LIMIT_TOKENS,
            token_distances=token_distances)
    else:
        dataset = CTCodeSummarizationDataset(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=model_config['data_setup']
            ['use_pointer_network'],
            max_num_tokens=LIMIT_TOKENS,
            token_distances=token_distances)

    # Hijack dataset to only contain the user-specified code snippet
    dataset.dataset = (stage2 for _ in range(1))
    processed_sample = next(dataset)
    batch = dataset.collate_fn([processed_sample])

    # Obtain model prediction
    output = model.forward_batch(batch)
    return output
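# Hypothetical call sketch for predict_method_name (illustrative; it assumes `model`
# and `model_config` have already been loaded from a trained run elsewhere, and that
# the snippet is valid Java source for a Java-trained model):
#
#   snippet = "public int f(int[] xs) { int s = 0; for (int x : xs) s += x; return s; }"
#   output = predict_method_name(model, model_config, snippet)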
def _init_data(self,
               language,
               use_validation=False,
               mini_dataset=False,
               num_sub_tokens=5,
               num_subtokens_output=5,
               use_only_ast=False,
               sort_by_length=False,
               shuffle=True,
               use_pointer_network=False):
    self.num_sub_tokens = num_sub_tokens
    self.use_validation = use_validation
    self.use_pointer_network = use_pointer_network
    self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2,
                                              language,
                                              shuffle=shuffle,
                                              infinite_loading=True,
                                              mini_dataset=mini_dataset,
                                              sort_by_length=sort_by_length)
    self.word_vocab, self.token_type_vocab, self.node_type_vocab = \
        self.data_manager.load_vocabularies()
    if ',' in language:
        self.num_languages = len(language.split(','))
    self.use_separate_vocab = False

    token_distances = None
    if TokenDistancesTransform.name in self.relative_distances:
        token_distances = TokenDistancesTransform(
            DistanceBinning(
                self.data_manager.load_config()['binning']['num_bins'],
                self.distance_binning['n_fixed_bins'],
                self.distance_binning['trans_func']))

    self.dataset_train = CTCodeSummarizationDatasetEdgeTypes(
        self.data_manager,
        token_distances=token_distances,
        max_distance_mask=self.max_distance_mask,
        num_sub_tokens=num_sub_tokens,
        num_sub_tokens_output=num_subtokens_output,
        use_pointer_network=use_pointer_network)

    if self.use_validation:
        data_manager_validation = CTBufferedDataManager(
            DATA_PATH_STAGE_2,
            language,
            partition="valid",
            shuffle=True,
            infinite_loading=True,
            mini_dataset=mini_dataset)
        self.dataset_validation = CTCodeSummarizationDatasetEdgeTypes(
            data_manager_validation,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output,
            use_pointer_network=use_pointer_network)

    self.dataset_validation_creator = \
        lambda infinite_loading: self._create_validation_dataset(
            DATA_PATH_STAGE_2,
            language,
            token_distances,
            num_sub_tokens,
            num_subtokens_output,
            infinite_loading=infinite_loading,
            use_pointer_network=use_pointer_network,
            filter_language=None,
            dataset_imbalance=None)
def _init_data(self,
               language,
               use_validation=False,
               mini_dataset=False,
               num_sub_tokens=NUM_SUB_TOKENS,
               num_subtokens_output=NUM_SUB_TOKENS_METHOD_NAME,
               use_only_ast=False,
               use_no_punctuation=False,
               use_pointer_network=False,
               sort_by_length=False,
               chunk_size=None,
               filter_language=None,
               dataset_imbalance=None,
               mask_all_tokens=False):
    self.data_manager = CTBufferedDataManager(
        DATA_PATH_STAGE_2,
        language,
        shuffle=True,
        infinite_loading=True,
        mini_dataset=mini_dataset,
        sort_by_length=sort_by_length,
        chunk_size=chunk_size,
        filter_language=filter_language,
        dataset_imbalance=dataset_imbalance)
    vocabs = self.data_manager.load_vocabularies()
    if len(vocabs) == 4:
        self.word_vocab, self.token_type_vocab, self.node_type_vocab, \
            self.method_name_vocab = vocabs
        self.use_separate_vocab = True
    else:
        self.word_vocab, self.token_type_vocab, self.node_type_vocab = vocabs
        self.use_separate_vocab = False
    if ',' in language:
        self.num_languages = len(language.split(','))

    token_distances = None
    if TokenDistancesTransform.name in self.relative_distances:
        print('Token distances will be calculated in dataset.')
        num_bins = self.data_manager.load_config()['binning']['num_bins']
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins,
                            self.distance_binning['n_fixed_bins'],
                            self.distance_binning['trans_func']))

    self.use_only_ast = use_only_ast
    self.use_pointer_network = use_pointer_network
    self.num_sub_tokens = num_sub_tokens

    if use_only_ast:
        self.dataset_train = CTCodeSummarizationOnlyASTDataset(
            self.data_manager,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output,
            use_pointer_network=use_pointer_network,
            mask_all_tokens=mask_all_tokens)
    elif use_no_punctuation:
        self.dataset_train = CTCodeSummarizationDatasetNoPunctuation(
            self.data_manager,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output,
            use_pointer_network=use_pointer_network)
    else:
        self.dataset_train = CTCodeSummarizationDataset(
            self.data_manager,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output,
            use_pointer_network=use_pointer_network)

    self.use_validation = use_validation
    if self.use_validation:
        data_manager_validation = CTBufferedDataManager(
            DATA_PATH_STAGE_2,
            language,
            partition="valid",
            shuffle=True,
            infinite_loading=True,
            mini_dataset=mini_dataset,
            filter_language=filter_language,
            dataset_imbalance=dataset_imbalance)
        if use_only_ast:
            self.dataset_validation = CTCodeSummarizationOnlyASTDataset(
                data_manager_validation,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network,
                mask_all_tokens=mask_all_tokens)
        elif use_no_punctuation:
            self.dataset_validation = CTCodeSummarizationDatasetNoPunctuation(
                data_manager_validation,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)
        else:
            self.dataset_validation = CTCodeSummarizationDataset(
                data_manager_validation,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)

    self.dataset_validation_creator = \
        lambda infinite_loading: self._create_validation_dataset(
            DATA_PATH_STAGE_2, language, use_only_ast, use_no_punctuation,
            token_distances, num_sub_tokens, num_subtokens_output,
            infinite_loading, use_pointer_network, filter_language,
            dataset_imbalance, mask_all_tokens)
def make_batch_from_sample(stage2_sample: CTStage2Sample, model_config,
                           model_type):
    assert isinstance(stage2_sample.token_mapping,
                      dict), "Please re-generate the sample"
    data_manager = CTPreprocessedDataManager(
        DATA_PATH_STAGE_2,
        model_config['data_setup']['language'],
        partition='train',
        shuffle=True)

    # Setup dataset to generate batch as input for model
    LIMIT_TOKENS = 1000
    token_distances = None
    if TokenDistancesTransform.name in model_config['data_transforms'][
            'relative_distances']:
        num_bins = data_manager.load_config()['binning']['num_bins']
        distance_binning_config = model_config['data_transforms'][
            'distance_binning']
        if distance_binning_config['type'] == 'exponential':
            trans_func = ExponentialBinning(
                distance_binning_config['growth_factor'])
        else:
            trans_func = EqualBinning()
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins,
                            distance_binning_config['n_fixed_bins'],
                            trans_func))

    use_pointer_network = model_config['data_setup']['use_pointer_network']
    if model_type in {'great'}:
        dataset_type = 'great'
    elif 'use_only_ast' in model_config['data_setup'] and model_config[
            'data_setup']['use_only_ast']:
        dataset_type = 'only_ast'
    elif 'use_no_punctuation' in model_config['data_setup'] and model_config[
            'data_setup']['use_no_punctuation']:
        dataset_type = 'no_punctuation'
    else:
        dataset_type = 'regular'

    if dataset_type == 'great':
        dataset = CTCodeSummarizationDatasetEdgeTypes(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=use_pointer_network,
            token_distances=token_distances,
            max_num_tokens=LIMIT_TOKENS)
    elif dataset_type == 'regular':
        dataset = CTCodeSummarizationDataset(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=use_pointer_network,
            max_num_tokens=LIMIT_TOKENS,
            token_distances=token_distances)
    elif dataset_type == 'no_punctuation':
        dataset = CTCodeSummarizationDatasetNoPunctuation(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=use_pointer_network,
            max_num_tokens=LIMIT_TOKENS,
            token_distances=token_distances)
    elif dataset_type == 'only_ast':
        dataset = CTCodeSummarizationOnlyASTDataset(
            data_manager,
            num_sub_tokens_output=NUM_SUB_TOKENS_METHOD_NAME,
            use_pointer_network=use_pointer_network,
            max_num_tokens=LIMIT_TOKENS,
            token_distances=token_distances)
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    # Hijack dataset to only contain the user-specified code snippet
    dataset.dataset = (stage2_sample for _ in range(1))
    processed_sample = next(dataset)
    batch = dataset.collate_fn([processed_sample])
    return batch
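# Hypothetical usage sketch (illustrative; `stage2_sample` must be a CTStage2Sample
# produced by the stage-2 preprocessing pipeline, and `model`, `model_config`, and
# `model_type` come from a trained run loaded elsewhere):
#
#   batch = make_batch_from_sample(stage2_sample, model_config, model_type)
#   output = model.forward_batch(batch)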
vocabularies = data_manager.load_vocabularies()
if len(vocabularies) == 3:
    word_vocab, _, _ = vocabularies
else:
    word_vocab, _, _, _ = vocabularies

token_distances = None
if TokenDistancesTransform.name in config['data_transforms']['relative_distances']:
    num_bins = data_manager.load_config()['binning']['num_bins']
    distance_binning_config = config['data_transforms']['distance_binning']
    if distance_binning_config['type'] == 'exponential':
        trans_func = ExponentialBinning(distance_binning_config['growth_factor'])
    else:
        trans_func = EqualBinning()
    token_distances = TokenDistancesTransform(
        DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'],
                        trans_func))

use_pointer_network = config['data_setup']['use_pointer_network']

if args.model in {'great'}:
    dataset_type = 'great'
elif 'use_only_ast' in config['data_setup'] and config['data_setup']['use_only_ast']:
    dataset_type = 'only_ast'
elif 'use_no_punctuation' in config['data_setup'] and config['data_setup']['use_no_punctuation']:
    dataset_type = 'no_punctuation'
else:
    dataset_type = 'regular'

print(
    f"Evaluating model snapshot-{args.snapshot_iteration} from run {args.run_id} "
    f"on {config['data_setup']['language']} partition {args.partition}")
print(f"gpu: {not args.no_gpu}")
print(f"dataset_type: {dataset_type}")
def _init_data(self,
               language,
               num_predict,
               use_validation=False,
               mini_dataset=False,
               use_no_punctuation=False,
               use_pointer_network=False,
               sort_by_length=False,
               shuffle=True,
               chunk_size=None,
               filter_language=None,
               dataset_imbalance=None,
               num_sub_tokens=NUM_SUB_TOKENS):
    self.data_manager = CTBufferedDataManager(
        DATA_PATH_STAGE_2,
        language,
        shuffle=shuffle,
        infinite_loading=True,
        mini_dataset=mini_dataset,
        size_load_buffer=10000,
        sort_by_length=sort_by_length,
        chunk_size=chunk_size,
        filter_language=filter_language,
        dataset_imbalance=dataset_imbalance)
    self.word_vocab, self.token_type_vocab, self.node_type_vocab = \
        self.data_manager.load_vocabularies()

    token_distances = None
    if TokenDistancesTransform.name in self.relative_distances:
        num_bins = self.data_manager.load_config()['binning']['num_bins']
        token_distances = TokenDistancesTransform(
            DistanceBinning(num_bins,
                            self.distance_binning['n_fixed_bins'],
                            self.distance_binning['trans_func']))

    self.num_predict = num_predict
    self.use_pointer_network = use_pointer_network
    # For language modeling we always only operate on the method body vocabulary
    self.use_separate_vocab = False

    if use_no_punctuation:
        self.dataset_train = CTLanguageModelingDatasetNoPunctuation(
            self.data_manager,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_labels_per_sample=num_predict,
            use_pointer_network=use_pointer_network,
            num_sub_tokens=num_sub_tokens)
    else:
        self.dataset_train = CTLanguageModelingDataset(
            self.data_manager,
            token_distances=token_distances,
            max_distance_mask=self.max_distance_mask,
            num_labels_per_sample=num_predict,
            use_pointer_network=use_pointer_network,
            num_sub_tokens=num_sub_tokens)

    self.use_validation = use_validation
    if self.use_validation:
        data_manager_validation = CTBufferedDataManager(
            DATA_PATH_STAGE_2,
            language,
            partition="valid",
            shuffle=True,
            infinite_loading=True,
            mini_dataset=mini_dataset,
            size_load_buffer=10000,
            filter_language=filter_language,
            dataset_imbalance=dataset_imbalance)
        if use_no_punctuation:
            self.dataset_validation = CTLanguageModelingDatasetNoPunctuation(
                data_manager_validation,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network,
                num_sub_tokens=num_sub_tokens)
        else:
            self.dataset_validation = CTLanguageModelingDataset(
                data_manager_validation,
                token_distances=token_distances,
                max_distance_mask=self.max_distance_mask,
                num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network,
                num_sub_tokens=num_sub_tokens)

    self.dataset_validation_creator = \
        lambda infinite_loading: self._create_validation_dataset(
            DATA_PATH_STAGE_2, language, use_no_punctuation, token_distances,
            infinite_loading, num_predict, use_pointer_network,
            filter_language, dataset_imbalance, num_sub_tokens)