class Preprocess1Container:

    def __init__(self):
        self._init_execution()
        self._init_preprocessing()
        self._init_data()

        self.logger = get_logger(__file__)
        random.seed(self.random_seed)

        self._setup_data_manager()
        self._setup_vocab_builder()
        self._setup_preprocessor()
        self._setup_data_loader()
        self._store_config()

    # =========================================================================
    # Parameter initializations
    # =========================================================================

    @ex.capture(prefix='execution')
    def _init_execution(self, num_processes, batch_size, save_every, random_seed):
        self.num_processes = num_processes
        self.batch_size = batch_size
        self.save_every = save_every
        self.random_seed = random_seed

    @ex.capture(prefix='preprocessing')
    def _init_preprocessing(self, hard_num_tokens_limit, allow_empty_methods, use_tokens_limiter,
                            separate_label_vocabulary):
        self.hard_num_tokens_limit = hard_num_tokens_limit
        self.allow_empty_methods = allow_empty_methods
        self.use_tokens_limiter = use_tokens_limiter
        self.separate_label_vocabulary = separate_label_vocabulary

    @ex.capture(prefix='data')
    def _init_data(self, language, partition):
        self.partition = partition
        if language in {'java-small', 'java-medium', 'java-large'}:
            self.dataset_name = language
            self.language = "java"
            self.dataset_type = "code2seq"
        else:
            self.dataset_name = language
            self.dataset_type = "code-search-net"
            self.language = language

        if self.dataset_type == 'code2seq':
            self.input_data_path = CODE2SEQ_EXTRACTED_METHODS_DATA_PATH
        else:
            self.input_data_path = CSN_RAW_DATA_PATH

    @ex.capture
    def _store_config(self, _config):
        self.data_manager.save_config(_config)

    # =========================================================================
    # Setup Helper Methods
    # =========================================================================

    def _setup_data_manager(self):
        self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_1, self.dataset_name, self.partition)

    def _setup_vocab_builder(self):
        if self.partition == 'train':
            word_counter = WordCounter()
            token_type_counter = WordCounter()
            node_type_counter = WordCounter()
            if self.separate_label_vocabulary:
                word_counter_labels = WordCounter()
                self.word_counters = (word_counter, token_type_counter, node_type_counter, word_counter_labels)
                self.vocab_builder = CodeSummarizationVocabularyBuilder(*self.word_counters)
            else:
                self.word_counters = (word_counter, token_type_counter, node_type_counter)
                self.vocab_builder = VocabularyBuilder(*self.word_counters)

    def _setup_preprocessor(self):
        self.preprocessor = CTStage1Preprocessor(self.language,
                                                 allow_empty_methods=self.allow_empty_methods,
                                                 use_tokens_limiter=self.use_tokens_limiter,
                                                 max_num_tokens=self.hard_num_tokens_limit)

    def _setup_data_loader(self):
        self.logger.info(f"loading dataset {self.dataset_name} for {self.language}...")
        if self.dataset_type == 'code2seq':
            self.dataloader = C2SRawDataLoader(self.input_data_path)
            self.dataloader.load_dataset(self.dataset_name, partition=self.partition)
        else:
            self.dataloader = CSNRawDataLoader(self.input_data_path)
            self.dataloader.load_all_for(self.language, partition=self.partition)
        self.n_raw_samples = len(self.dataloader)
        self.logger.info(f"Loaded {self.n_raw_samples} snippets")

    # =========================================================================
    # Processing Helper Methods
    # =========================================================================

    @staticmethod
    def _process_batch(preprocessor, x):
        i, batch = x
        try:
            return preprocessor.process(batch, i)
        except PreprocessingException as e:
            # Cannot use the logger in a parallel worker, as loggers cannot be pickled
            print(str(e))
            # This is an expected exception, so we just return an empty list so that preprocessing can continue
            return []
        except Exception as e:
            print(f"Error processing batch {i}:")
            func_names, docstrings, code_snippets = zip(*batch)
            print(str(e))
            for snippet in code_snippets:
                print(snippet)
            traceback.print_exc()
            return []

    def _save_dataset(self, dataset):
        # Building the vocabulary before saving ensures that it runs in the main process
        # => no race conditions when updating the vocabulary
        if self.partition == 'train':
            for sample in dataset:
                self.vocab_builder(sample)
        dataset = [sample.compress() for sample in dataset]
        if dataset:
            self.logger.info(f"saving dataset batch with {len(dataset)} samples ...")
            self.data_manager.save(dataset)

    def _handle_shutdown(self, sig=None, frame=None):
        self.data_manager.shutdown()
        sys.exit(0)

    # =========================================================================
    # Main method
    # =========================================================================

    def run(self):
        os.umask(0o007)
        # Ensure graceful shutdown when preprocessing is interrupted
        signal.signal(signal.SIGINT, self._handle_shutdown)

        n_processed_samples = 0
        with parallel_backend("loky"):
            execute_parallel = Parallel(self.num_processes, verbose=0)
            batched_samples_generator = enumerate(self.dataloader.read(self.batch_size, shuffle=True))
            while True:
                self.logger.info("start processing batch ...")
                dataset_slice = itertools.islice(batched_samples_generator, int(self.save_every / self.batch_size))
                with Timing() as t:
                    dataset = execute_parallel(
                        delayed(self._process_batch)(self.preprocessor, batch) for batch in dataset_slice)
                if dataset:
                    dataset = [sample for batch in dataset for sample in batch]  # List[batches] -> List[samples]
                    self.logger.info(f"processing {len(dataset)} samples took {t[0]:0.2f} seconds "
                                     f"({t[0] / len(dataset):0.3f} seconds per sample)")
                    self._save_dataset(dataset)
                    n_processed_samples += len(dataset)
                else:
                    break

        if self.partition == 'train':
            self.data_manager.save_word_counters(*self.word_counters)

        self.logger.info("PREPROCESS-1 DONE!")
        self.logger.info(f"Successfully processed {n_processed_samples}/{self.n_raw_samples} samples "
                         f"({n_processed_samples / self.n_raw_samples:0.2%})")
        self._handle_shutdown()
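

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the pipeline): Preprocess1Container.run()
# chunks an enumerated batch generator into slices of save_every / batch_size
# batches via itertools.islice and fans each slice out to joblib workers. The
# toy function below reproduces only that chunking pattern on plain lists,
# relying on the module's existing itertools import; the name
# `demo_slice_batches` is hypothetical and exists only for this illustration.
# -----------------------------------------------------------------------------
def demo_slice_batches(samples, batch_size=2, save_every=4):
    """Yield lists of (batch_index, batch) pairs, save_every // batch_size batches per slice."""
    batches = enumerate(samples[i:i + batch_size] for i in range(0, len(samples), batch_size))
    while True:
        dataset_slice = list(itertools.islice(batches, save_every // batch_size))
        if not dataset_slice:
            break
        yield dataset_slice

# e.g. list(demo_slice_batches(list(range(10)))) yields
# [[(0, [0, 1]), (1, [2, 3])], [(2, [4, 5]), (3, [6, 7])], [(4, [8, 9])]]
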
class Preprocess2Container:

    def __init__(self):
        self._init_execution()
        self._init_preprocessing()
        self._init_distances()
        self._init_binning()
        self._init_data()

        self.logger = get_logger(__file__)

        self._setup_data_managers()
        self._setup_vocabularies()
        self._setup_vocabulary_transformer()
        self._setup_distances_transformer()
        self._store_config()

    @ex.capture(prefix="execution")
    def _init_execution(self, num_processes, batch_size, dataset_slice_size):
        self.num_processes = num_processes
        self.batch_size = batch_size
        self.dataset_slice_size = dataset_slice_size

    @ex.capture(prefix="preprocessing")
    def _init_preprocessing(self, remove_punctuation, max_num_tokens, vocab_size, min_vocabulary_frequency,
                            separate_label_vocabulary, vocab_size_labels, min_vocabulary_frequency_labels):
        self.remove_punctuation = remove_punctuation
        self.max_num_tokens = max_num_tokens
        self.vocab_size = vocab_size
        self.min_vocabulary_frequency = min_vocabulary_frequency
        self.separate_label_vocabulary = separate_label_vocabulary
        self.vocab_size_labels = vocab_size_labels
        self.min_vocabulary_frequency_labels = min_vocabulary_frequency_labels

    @ex.capture(prefix="distances")
    def _init_distances(self, ppr_alpha, ppr_use_log, ppr_threshold, sp_threshold, ancestor_sp_forward,
                        ancestor_sp_backward, ancestor_sp_negative_reverse_dists, ancestor_sp_threshold,
                        sibling_sp_forward, sibling_sp_backward, sibling_sp_negative_reverse_dists,
                        sibling_sp_threshold):
        self.ppr_alpha = ppr_alpha
        self.ppr_use_log = ppr_use_log
        self.ppr_threshold = ppr_threshold
        self.sp_threshold = sp_threshold
        self.ancestor_sp_forward = ancestor_sp_forward
        self.ancestor_sp_backward = ancestor_sp_backward
        self.ancestor_sp_negative_reverse_dists = ancestor_sp_negative_reverse_dists
        self.ancestor_sp_threshold = ancestor_sp_threshold
        self.sibling_sp_forward = sibling_sp_forward
        self.sibling_sp_backward = sibling_sp_backward
        self.sibling_sp_negative_reverse_dists = sibling_sp_negative_reverse_dists
        self.sibling_sp_threshold = sibling_sp_threshold

    @ex.capture(prefix='data')
    def _init_data(self, language, partition):
        self.language = language
        self.partition = partition
        self.use_multi_language = ',' in self.language
        if self.use_multi_language:
            self.languages = self.language.split(',')
        self.input_path = DATA_PATH_STAGE_1
        self.output_path = DATA_PATH_STAGE_2

    @ex.capture(prefix="binning")
    def _init_binning(self, num_bins, n_fixed_bins, exponential_binning, exponential_binning_growth_factor,
                      bin_padding):
        self.num_bins = num_bins
        self.n_fixed_bins = n_fixed_bins
        self.exponential_binning = exponential_binning
        self.exponential_binning_growth_factor = exponential_binning_growth_factor
        self.bin_padding = bin_padding

    @ex.capture
    def _store_config(self, _config):
        config = deepcopy(_config)
        config['preprocessing']['special_symbols'] = SPECIAL_SYMBOLS
        config['preprocessing']['special_symbols_node_token_types'] = SPECIAL_SYMBOLS_NODE_TOKEN_TYPES
        self.output_data_manager.save_config(config)

    # =========================================================================
    # Setup Helper Methods
    # =========================================================================

    def _setup_data_managers(self):
        if self.use_multi_language:
            self.input_data_managers = [CTBufferedDataManager(self.input_path, l, self.partition)
                                        for l in self.languages]
            word_counters = zip(*[input_data_manager.load_word_counters()
                                  for input_data_manager in self.input_data_managers])
            word_counters = [self._combine_counters(word_counter, self.min_vocabulary_frequency)
                             if i < 3 or not self.separate_label_vocabulary
                             else self._combine_counters(word_counter, self.min_vocabulary_frequency_labels)
                             for i, word_counter in enumerate(word_counters)]
            self.input_data_manager = CombinedDataManager(self.input_data_managers, self.languages)
        else:
            self.input_data_manager = CTBufferedDataManager(self.input_path, self.language, self.partition)
            word_counters = self.input_data_manager.load_word_counters()

        if self.separate_label_vocabulary:
            self.word_counter, self.token_type_counter, self.node_type_counter, self.word_counter_labels = word_counters
        else:
            self.word_counter, self.token_type_counter, self.node_type_counter = word_counters

        self.output_data_manager = CTBufferedDataManager(self.output_path, self.language, self.partition)

    def _setup_vocabularies(self):
        if self.partition == "train":
            # Only build vocabularies on train data
            self.word_vocab = self.word_counter.to_vocabulary(limit_most_common=self.vocab_size,
                                                              min_frequency=self.min_vocabulary_frequency,
                                                              special_symbols=SPECIAL_SYMBOLS)
            self.token_type_vocab = self.token_type_counter.to_vocabulary(
                special_symbols=SPECIAL_SYMBOLS_NODE_TOKEN_TYPES)
            self.node_type_vocab = self.node_type_counter.to_vocabulary(
                special_symbols=SPECIAL_SYMBOLS_NODE_TOKEN_TYPES)
            if self.separate_label_vocabulary:
                self.word_vocab_labels = self.word_counter_labels.to_vocabulary(
                    limit_most_common=self.vocab_size_labels,
                    min_frequency=self.min_vocabulary_frequency_labels,
                    special_symbols=SPECIAL_SYMBOLS)
        else:
            # On the valid and test partitions, use the vocabulary already built during the train run
            if self.separate_label_vocabulary:
                self.word_vocab, self.token_type_vocab, self.node_type_vocab, self.word_vocab_labels = \
                    self.output_data_manager.load_vocabularies()
            else:
                self.word_vocab, self.token_type_vocab, self.node_type_vocab = \
                    self.output_data_manager.load_vocabularies()

    def _setup_vocabulary_transformer(self):
        if self.separate_label_vocabulary:
            self.vocabulary_transformer = CodeSummarizationVocabularyTransformer(self.word_vocab,
                                                                                 self.token_type_vocab,
                                                                                 self.node_type_vocab,
                                                                                 self.word_vocab_labels)
        else:
            self.vocabulary_transformer = VocabularyTransformer(self.word_vocab, self.token_type_vocab,
                                                                self.node_type_vocab)

    def _setup_distances_transformer(self):
        distance_metrics = [
            PersonalizedPageRank(threshold=self.ppr_threshold, log=self.ppr_use_log, alpha=self.ppr_alpha),
            ShortestPaths(threshold=self.sp_threshold),
            AncestorShortestPaths(forward=self.ancestor_sp_forward,
                                  backward=self.ancestor_sp_backward,
                                  negative_reverse_dists=self.ancestor_sp_negative_reverse_dists,
                                  threshold=self.ancestor_sp_threshold),
            SiblingShortestPaths(forward=self.sibling_sp_forward,
                                 backward=self.sibling_sp_backward,
                                 negative_reverse_dists=self.sibling_sp_negative_reverse_dists,
                                 threshold=self.sibling_sp_threshold)
        ]
        if self.exponential_binning:
            db = DistanceBinning(self.num_bins, self.n_fixed_bins,
                                 ExponentialBinning(self.exponential_binning_growth_factor))
        else:
            db = DistanceBinning(self.num_bins, self.n_fixed_bins)

        self.distances_transformer = DistancesTransformer(distance_metrics, db)

    def _combine_counters(self, counters, min_vocab_frequency):
        """
        If multiple languages are used, we need to combine the word counts for all of them.
        Additionally, if min_vocab_frequency is set, we only keep a token if it passed the threshold in ANY of the
        languages. (A toy sketch of this combination on plain dicts is given at the end of this file.)
        """
        combined_counter = WordCounter()
        # DataFrame with token x language -> count
        df = pd.DataFrame.from_dict({language: counter.words
                                     for language, counter in zip(self.languages, counters)})
        df = df.fillna(0)
        if min_vocab_frequency is not None:
            idx_frequent_words = (df > min_vocab_frequency).any(axis=1)
            df = df[idx_frequent_words]
        df = df.sum(axis=1)
        combined_counter.words = df.to_dict()
        return combined_counter

    @staticmethod
    def _process_batch(x, vocabulary_transformer, distances_transformer, use_multi_language, max_num_tokens,
                       remove_punctuation):
        i, batch = x
        try:
            output = []
            for sample in batch:
                if use_multi_language:
                    sample_language, sample = sample
                assert len(sample) == 6, f"Unexpected sample format! {sample}"
                sample = CTStage1Sample.from_compressed(sample)
                if max_num_tokens is not None and len(sample.tokens) > max_num_tokens:
                    print(f"Snippet with {len(sample.tokens)} tokens exceeds limit of {max_num_tokens}! Skipping")
                    continue
                if remove_punctuation:
                    sample.remove_punctuation()
                sample = vocabulary_transformer(sample)
                sample = distances_transformer(sample)
                if use_multi_language:
                    sample = CTStage2MultiLanguageSample(
                        sample.tokens, sample.graph_sample, sample.token_mapping, sample.stripped_code_snippet,
                        sample.func_name, sample.docstring, sample_language,
                        encoded_func_name=sample.encoded_func_name
                        if hasattr(sample, 'encoded_func_name') else None)
                output.append(sample)
            return output
        except Exception as e:
            # Cannot use the logger in a parallel worker, as loggers cannot be pickled
            print(str(e))
            traceback.print_exc()
            return []

    def _handle_shutdown(self, sig=None, frame=None):
        if self.use_multi_language:
            for idm in self.input_data_managers:
                idm.shutdown()
        else:
            self.input_data_manager.shutdown()
        self.output_data_manager.shutdown()
        sys.exit(0)

    def run(self):
        # -----------------------------------------------------------------------------
        # Multiprocessing Loop
        # -----------------------------------------------------------------------------
        self.logger.info("Start processing...")
        # Ensure graceful shutdown when preprocessing is interrupted
        signal.signal(signal.SIGINT, self._handle_shutdown)

        n_samples_after = 0
        with parallel_backend("loky"):
            execute_parallel = Parallel(self.num_processes, verbose=0)
            batched_samples_generator = enumerate(self.input_data_manager.read(self.batch_size))
            while True:
                dataset_slice = itertools.islice(batched_samples_generator,
                                                 int(self.dataset_slice_size / self.batch_size))
                with Timing() as t:
                    dataset = execute_parallel(
                        delayed(self._process_batch)(batch, self.vocabulary_transformer, self.distances_transformer,
                                                     self.use_multi_language, self.max_num_tokens,
                                                     self.remove_punctuation)
                        for batch in dataset_slice)
                if dataset:
                    dataset = [sample for batch in dataset for sample in batch]
                    n_samples_after += len(dataset)
                    self.logger.info(f"processing {len(dataset)} samples took {t[0]:0.2f} seconds "
                                     f"({t[0] / len(dataset):0.3f} seconds per sample)")
                    self.output_data_manager.save(dataset)
                else:
                    break

        self.logger.info("Saving vocabulary")
        if self.separate_label_vocabulary:
            self.output_data_manager.save_vocabularies(self.word_vocab, self.token_type_vocab, self.node_type_vocab,
                                                       self.word_vocab_labels)
        else:
            self.output_data_manager.save_vocabularies(self.word_vocab, self.token_type_vocab, self.node_type_vocab)

        self.logger.info("PREPROCESS-2 DONE!")
        self.logger.info(f"Successfully processed {n_samples_after} samples")
        self._handle_shutdown()
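

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the pipeline): the pandas logic behind
# Preprocess2Container._combine_counters, shown on plain dicts. Real inputs are
# WordCounter instances (counter.words holds the token -> count mapping); the
# name `demo_combine_counts` and the toy counts below are hypothetical, and the
# sketch relies on the module's existing pandas import (pd). A token is kept if
# its count passes the frequency threshold in ANY language, and the surviving
# counts are summed over all languages.
# -----------------------------------------------------------------------------
def demo_combine_counts(counts_per_language, min_vocab_frequency=None):
    # counts_per_language: dict mapping language -> {token: count}
    df = pd.DataFrame.from_dict(counts_per_language).fillna(0)  # token x language
    if min_vocab_frequency is not None:
        df = df[(df > min_vocab_frequency).any(axis=1)]  # keep tokens frequent in at least one language
    return df.sum(axis=1).to_dict()  # summed counts per token

# e.g. demo_combine_counts({'python': {'def': 120, 'rare': 1}, 'go': {'func': 95, 'rare': 1}},
#                          min_vocab_frequency=2)
# -> {'def': 120.0, 'func': 95.0}   ('rare' never exceeds the threshold in any language)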