elif args.model == 'xl_net':
    model_manager = XLNetModelManager()
else:
    raise ValueError(f"Unknown model type `{args.model}`")

model = model_manager.load_model(args.run_id, args.snapshot_iteration, gpu=not args.no_gpu)
model = model.eval()
if not args.no_gpu:
    model = model.cuda()

config = model_manager.load_config(args.run_id)
data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2,
                                     config['data_setup']['language'],
                                     partition=args.partition,
                                     shuffle=False)
vocabularies = data_manager.load_vocabularies()
if len(vocabularies) == 3:
    word_vocab, _, _ = vocabularies
else:
    word_vocab, _, _, _ = vocabularies

token_distances = None
if TokenDistancesTransform.name in config['data_transforms']['relative_distances']:
    num_bins = data_manager.load_config()['binning']['num_bins']
    distance_binning_config = config['data_transforms']['distance_binning']
    if distance_binning_config['type'] == 'exponential':
        trans_func = ExponentialBinning(distance_binning_config['growth_factor'])
    else:
        trans_func = EqualBinning()
    token_distances = TokenDistancesTransform(
        DistanceBinning(num_bins, distance_binning_config['n_fixed_bins'], trans_func))
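
# Illustrative helper (not part of the original script): the branch above that picks between
# ExponentialBinning and EqualBinning is repeated in ExperimentSetup._init_data_transforms
# further down, so it could be factored out. `binning_config` is assumed to be the dict stored
# under config['data_transforms']['distance_binning'].
def build_token_distances_transform(binning_config, num_bins):
    if binning_config['type'] == 'exponential':
        trans_func = ExponentialBinning(binning_config['growth_factor'])
    else:
        trans_func = EqualBinning()
    return TokenDistancesTransform(
        DistanceBinning(num_bins, binning_config['n_fixed_bins'], trans_func))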
class CTCodeSummarizationGreatMixin(ExperimentSetup, ABC):

    @ex.capture(prefix="data_setup")
    def _init_data(self, language, use_validation=False, mini_dataset=False, num_sub_tokens=5,
                   num_subtokens_output=5, use_only_ast=False, sort_by_length=False, shuffle=True,
                   use_pointer_network=False):
        self.num_sub_tokens = num_sub_tokens
        self.use_validation = use_validation
        self.use_pointer_network = use_pointer_network
        self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2, language, shuffle=shuffle,
                                                  infinite_loading=True, mini_dataset=mini_dataset,
                                                  sort_by_length=sort_by_length)
        self.word_vocab, self.token_type_vocab, self.node_type_vocab = self.data_manager.load_vocabularies()
        if ',' in language:
            self.num_languages = len(language.split(','))
        self.use_separate_vocab = False

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            token_distances = TokenDistancesTransform(
                DistanceBinning(self.data_manager.load_config()['binning']['num_bins'],
                                self.distance_binning['n_fixed_bins'],
                                self.distance_binning['trans_func']))

        self.dataset_train = CTCodeSummarizationDatasetEdgeTypes(
            self.data_manager, token_distances=token_distances,
            max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output, use_pointer_network=use_pointer_network)
        if self.use_validation:
            data_manager_validation = CTBufferedDataManager(DATA_PATH_STAGE_2, language,
                                                            partition="valid", shuffle=True,
                                                            infinite_loading=True,
                                                            mini_dataset=mini_dataset)
            self.dataset_validation = CTCodeSummarizationDatasetEdgeTypes(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output, use_pointer_network=use_pointer_network)
        self.dataset_validation_creator = \
            lambda infinite_loading: self._create_validation_dataset(DATA_PATH_STAGE_2, language,
                                                                     token_distances, num_sub_tokens,
                                                                     num_subtokens_output,
                                                                     infinite_loading=infinite_loading,
                                                                     use_pointer_network=use_pointer_network,
                                                                     filter_language=None,
                                                                     dataset_imbalance=None)

    def _create_validation_dataset(self, data_location, language, token_distances, num_sub_tokens,
                                   num_subtokens_output, infinite_loading, use_pointer_network,
                                   filter_language, dataset_imbalance):
        data_manager_validation = CTBufferedDataManager(data_location, language, partition="valid",
                                                        shuffle=True, infinite_loading=infinite_loading)
        dataset_validation = CTCodeSummarizationDatasetEdgeTypes(
            data_manager_validation, token_distances=token_distances,
            max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
            num_sub_tokens_output=num_subtokens_output, use_pointer_network=use_pointer_network)
        return dataset_validation
class Preprocess2Container:

    def __init__(self):
        self._init_execution()
        self._init_preprocessing()
        self._init_distances()
        self._init_binning()
        self._init_data()
        self.logger = get_logger(__file__)
        self._setup_data_managers()
        self._setup_vocabularies()
        self._setup_vocabulary_transformer()
        self._setup_distances_transformer()
        self._store_config()

    @ex.capture(prefix="execution")
    def _init_execution(self, num_processes, batch_size, dataset_slice_size):
        self.num_processes = num_processes
        self.batch_size = batch_size
        self.dataset_slice_size = dataset_slice_size

    @ex.capture(prefix="preprocessing")
    def _init_preprocessing(self, remove_punctuation, max_num_tokens, vocab_size,
                            min_vocabulary_frequency, separate_label_vocabulary,
                            vocab_size_labels, min_vocabulary_frequency_labels):
        self.remove_punctuation = remove_punctuation
        self.max_num_tokens = max_num_tokens
        self.vocab_size = vocab_size
        self.min_vocabulary_frequency = min_vocabulary_frequency
        self.separate_label_vocabulary = separate_label_vocabulary
        self.vocab_size_labels = vocab_size_labels
        self.min_vocabulary_frequency_labels = min_vocabulary_frequency_labels

    @ex.capture(prefix="distances")
    def _init_distances(self, ppr_alpha, ppr_use_log, ppr_threshold, sp_threshold,
                        ancestor_sp_forward, ancestor_sp_backward,
                        ancestor_sp_negative_reverse_dists, ancestor_sp_threshold,
                        sibling_sp_forward, sibling_sp_backward,
                        sibling_sp_negative_reverse_dists, sibling_sp_threshold):
        self.ppr_alpha = ppr_alpha
        self.ppr_use_log = ppr_use_log
        self.ppr_threshold = ppr_threshold
        self.sp_threshold = sp_threshold
        self.ancestor_sp_forward = ancestor_sp_forward
        self.ancestor_sp_backward = ancestor_sp_backward
        self.ancestor_sp_negative_reverse_dists = ancestor_sp_negative_reverse_dists
        self.ancestor_sp_threshold = ancestor_sp_threshold
        self.sibling_sp_forward = sibling_sp_forward
        self.sibling_sp_backward = sibling_sp_backward
        self.sibling_sp_negative_reverse_dists = sibling_sp_negative_reverse_dists
        self.sibling_sp_threshold = sibling_sp_threshold

    @ex.capture(prefix='data')
    def _init_data(self, language, partition):
        self.language = language
        self.partition = partition
        self.use_multi_language = ',' in self.language
        if self.use_multi_language:
            self.languages = self.language.split(',')
        self.input_path = DATA_PATH_STAGE_1
        self.output_path = DATA_PATH_STAGE_2

    @ex.capture(prefix="binning")
    def _init_binning(self, num_bins, n_fixed_bins, exponential_binning,
                      exponential_binning_growth_factor, bin_padding):
        self.num_bins = num_bins
        self.n_fixed_bins = n_fixed_bins
        self.exponential_binning = exponential_binning
        self.exponential_binning_growth_factor = exponential_binning_growth_factor
        self.bin_padding = bin_padding

    @ex.capture
    def _store_config(self, _config):
        config = deepcopy(_config)
        config['preprocessing']['special_symbols'] = SPECIAL_SYMBOLS
        config['preprocessing']['special_symbols_node_token_types'] = SPECIAL_SYMBOLS_NODE_TOKEN_TYPES
        self.output_data_manager.save_config(config)

    # =========================================================================
    # Setup Helper Methods
    # =========================================================================

    def _setup_data_managers(self):
        if self.use_multi_language:
            self.input_data_managers = [CTBufferedDataManager(self.input_path, l, self.partition)
                                        for l in self.languages]
            word_counters = zip(*[input_data_manager.load_word_counters()
                                  for input_data_manager in self.input_data_managers])
            word_counters = [self._combine_counters(word_counter, self.min_vocabulary_frequency)
                             if i < 3 or not self.separate_label_vocabulary
                             else self._combine_counters(word_counter,
                                                         self.min_vocabulary_frequency_labels)
                             for i, word_counter in enumerate(word_counters)]
            self.input_data_manager = CombinedDataManager(self.input_data_managers, self.languages)
        else:
            self.input_data_manager = CTBufferedDataManager(self.input_path, self.language,
                                                            self.partition)
            word_counters = self.input_data_manager.load_word_counters()

        if self.separate_label_vocabulary:
            self.word_counter, self.token_type_counter, self.node_type_counter, \
                self.word_counter_labels = word_counters
        else:
            self.word_counter, self.token_type_counter, self.node_type_counter = word_counters

        self.output_data_manager = CTBufferedDataManager(self.output_path, self.language,
                                                         self.partition)

    def _setup_vocabularies(self):
        if self.partition == "train":
            # Only build vocabularies on train data
            self.word_vocab = self.word_counter.to_vocabulary(
                limit_most_common=self.vocab_size,
                min_frequency=self.min_vocabulary_frequency,
                special_symbols=SPECIAL_SYMBOLS)
            self.token_type_vocab = self.token_type_counter.to_vocabulary(
                special_symbols=SPECIAL_SYMBOLS_NODE_TOKEN_TYPES)
            self.node_type_vocab = self.node_type_counter.to_vocabulary(
                special_symbols=SPECIAL_SYMBOLS_NODE_TOKEN_TYPES)
            if self.separate_label_vocabulary:
                self.word_vocab_labels = self.word_counter_labels.to_vocabulary(
                    limit_most_common=self.vocab_size_labels,
                    min_frequency=self.min_vocabulary_frequency_labels,
                    special_symbols=SPECIAL_SYMBOLS)
        else:
            # On valid and test set, use the already built vocabulary from the train run
            if self.separate_label_vocabulary:
                self.word_vocab, self.token_type_vocab, self.node_type_vocab, \
                    self.word_vocab_labels = self.output_data_manager.load_vocabularies()
            else:
                self.word_vocab, self.token_type_vocab, self.node_type_vocab = \
                    self.output_data_manager.load_vocabularies()

    def _setup_vocabulary_transformer(self):
        if self.separate_label_vocabulary:
            self.vocabulary_transformer = CodeSummarizationVocabularyTransformer(
                self.word_vocab, self.token_type_vocab, self.node_type_vocab,
                self.word_vocab_labels)
        else:
            self.vocabulary_transformer = VocabularyTransformer(
                self.word_vocab, self.token_type_vocab, self.node_type_vocab)

    def _setup_distances_transformer(self):
        distance_metrics = [
            PersonalizedPageRank(threshold=self.ppr_threshold, log=self.ppr_use_log,
                                 alpha=self.ppr_alpha),
            ShortestPaths(threshold=self.sp_threshold),
            AncestorShortestPaths(forward=self.ancestor_sp_forward,
                                  backward=self.ancestor_sp_backward,
                                  negative_reverse_dists=self.ancestor_sp_negative_reverse_dists,
                                  threshold=self.ancestor_sp_threshold),
            SiblingShortestPaths(forward=self.sibling_sp_forward,
                                 backward=self.sibling_sp_backward,
                                 negative_reverse_dists=self.sibling_sp_negative_reverse_dists,
                                 threshold=self.sibling_sp_threshold)
        ]
        if self.exponential_binning:
            db = DistanceBinning(self.num_bins, self.n_fixed_bins,
                                 ExponentialBinning(self.exponential_binning_growth_factor))
        else:
            db = DistanceBinning(self.num_bins, self.n_fixed_bins)
        self.distances_transformer = DistancesTransformer(distance_metrics, db)

    def _combine_counters(self, counters, min_vocab_frequency):
        """
        If multiple languages are used, we need to combine the word counts for all of them.
        Additionally, if MIN_VOCABULARY_FREQUENCY is set, we only keep a token if it passed the
        threshold in ANY of the languages.
        """
        combined_counter = WordCounter()
        # DataFrame with token x language -> count
        df = pd.DataFrame.from_dict({language: counter.words
                                     for language, counter in zip(self.languages, counters)})
        df = df.fillna(0)
        if min_vocab_frequency is not None:
            idx_frequent_words = (df > min_vocab_frequency).any(axis=1)
            df = df[idx_frequent_words]
        df = df.sum(axis=1)
        combined_counter.words = df.to_dict()
        return combined_counter

    @staticmethod
    def _process_batch(x, vocabulary_transformer, distances_transformer, use_multi_language,
                       max_num_tokens, remove_punctuation):
        i, batch = x
        try:
            output = []
            for sample in batch:
                if use_multi_language:
                    sample_language, sample = sample
                assert len(sample) == 6, f"Unexpected sample format! {sample}"
                sample = CTStage1Sample.from_compressed(sample)
                if max_num_tokens is not None and len(sample.tokens) > max_num_tokens:
                    print(f"Snippet with {len(sample.tokens)} tokens exceeds limit of "
                          f"{max_num_tokens}! Skipping")
                    continue
                if remove_punctuation:
                    sample.remove_punctuation()
                sample = vocabulary_transformer(sample)
                sample = distances_transformer(sample)
                if use_multi_language:
                    sample = CTStage2MultiLanguageSample(
                        sample.tokens, sample.graph_sample, sample.token_mapping,
                        sample.stripped_code_snippet, sample.func_name, sample.docstring,
                        sample_language,
                        encoded_func_name=sample.encoded_func_name
                        if hasattr(sample, 'encoded_func_name') else None)
                output.append(sample)
            return output
        except Exception as e:
            # Cannot use logger in parallel worker, as loggers cannot be pickled
            print(str(e))
            traceback.print_exc()
            return []

    def _handle_shutdown(self, sig=None, frame=None):
        if self.use_multi_language:
            for idm in self.input_data_managers:
                idm.shutdown()
        else:
            self.input_data_manager.shutdown()
        self.output_data_manager.shutdown()
        sys.exit(0)

    def run(self):
        # -----------------------------------------------------------------------------
        # Multiprocessing Loop
        # -----------------------------------------------------------------------------
        self.logger.info("Start processing...")

        # Ensure graceful shutdown when preprocessing is interrupted
        signal.signal(signal.SIGINT, self._handle_shutdown)

        n_samples_after = 0
        with parallel_backend("loky") as parallel_config:
            execute_parallel = Parallel(self.num_processes, verbose=0)
            batched_samples_generator = enumerate(self.input_data_manager.read(self.batch_size))
            while True:
                dataset_slice = itertools.islice(batched_samples_generator,
                                                 int(self.dataset_slice_size / self.batch_size))
                with Timing() as t:
                    dataset = execute_parallel(
                        delayed(self._process_batch)(batch, self.vocabulary_transformer,
                                                     self.distances_transformer,
                                                     self.use_multi_language,
                                                     self.max_num_tokens,
                                                     self.remove_punctuation)
                        for batch in dataset_slice)
                if dataset:
                    dataset = [sample for batch in dataset for sample in batch]
                    n_samples_after += len(dataset)
                    self.logger.info(
                        f"Processing {len(dataset)} samples took {t[0]:0.2f} seconds "
                        f"({t[0] / len(dataset):0.3f} seconds per sample)")
                    self.output_data_manager.save(dataset)
                else:
                    break

        self.logger.info("Saving vocabulary")
        if self.separate_label_vocabulary:
            self.output_data_manager.save_vocabularies(self.word_vocab, self.token_type_vocab,
                                                       self.node_type_vocab,
                                                       self.word_vocab_labels)
        else:
            self.output_data_manager.save_vocabularies(self.word_vocab, self.token_type_vocab,
                                                       self.node_type_vocab)

        self.logger.info("PREPROCESS-2 DONE!")
        self.logger.info(f"Successfully processed {n_samples_after} samples")
        self._handle_shutdown()
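
# Minimal standalone sketch of the _combine_counters logic above, using plain dicts instead of
# WordCounter so it can run in isolation; the toy counts are invented for illustration. A token
# is kept if its count exceeds min_frequency in ANY language; the surviving counts are summed.
def combine_counts(counts_per_language, min_frequency):
    df = pd.DataFrame.from_dict(counts_per_language).fillna(0)
    df = df[(df > min_frequency).any(axis=1)]  # frequent enough in at least one language
    return df.sum(axis=1).to_dict()

# 'foo' survives via python alone, 'bar' via ruby; 'baz' misses the threshold everywhere:
# combine_counts({'python': {'foo': 10, 'baz': 1}, 'ruby': {'foo': 1, 'bar': 7, 'baz': 2}}, 2)
# -> {'foo': 11.0, 'bar': 7.0}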
class CTCodeSummarizationMixin(ExperimentSetup, ABC):

    @ex.capture(prefix="data_setup")
    def _init_data(self, language, use_validation=False, mini_dataset=False,
                   num_sub_tokens=NUM_SUB_TOKENS, num_subtokens_output=NUM_SUB_TOKENS_METHOD_NAME,
                   use_only_ast=False, use_no_punctuation=False, use_pointer_network=False,
                   sort_by_length=False, chunk_size=None, filter_language=None,
                   dataset_imbalance=None, mask_all_tokens=False):
        self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2, language, shuffle=True,
                                                  infinite_loading=True, mini_dataset=mini_dataset,
                                                  sort_by_length=sort_by_length,
                                                  chunk_size=chunk_size,
                                                  filter_language=filter_language,
                                                  dataset_imbalance=dataset_imbalance)
        vocabs = self.data_manager.load_vocabularies()
        if len(vocabs) == 4:
            self.word_vocab, self.token_type_vocab, self.node_type_vocab, \
                self.method_name_vocab = vocabs
            self.use_separate_vocab = True
        else:
            self.word_vocab, self.token_type_vocab, self.node_type_vocab = vocabs
            self.use_separate_vocab = False

        if ',' in language:
            self.num_languages = len(language.split(','))

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            print('Token distances will be calculated in dataset.')
            num_bins = self.data_manager.load_config()['binning']['num_bins']
            token_distances = TokenDistancesTransform(
                DistanceBinning(num_bins, self.distance_binning['n_fixed_bins'],
                                self.distance_binning['trans_func']))

        self.use_only_ast = use_only_ast
        self.use_pointer_network = use_pointer_network
        self.num_sub_tokens = num_sub_tokens
        if use_only_ast:
            self.dataset_train = CTCodeSummarizationOnlyASTDataset(
                self.data_manager, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network, mask_all_tokens=mask_all_tokens)
        elif use_no_punctuation:
            self.dataset_train = CTCodeSummarizationDatasetNoPunctuation(
                self.data_manager, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)
        else:
            self.dataset_train = CTCodeSummarizationDataset(
                self.data_manager, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)

        self.use_validation = use_validation
        if self.use_validation:
            data_manager_validation = CTBufferedDataManager(DATA_PATH_STAGE_2, language,
                                                            partition="valid", shuffle=True,
                                                            infinite_loading=True,
                                                            mini_dataset=mini_dataset,
                                                            filter_language=filter_language,
                                                            dataset_imbalance=dataset_imbalance)
            if use_only_ast:
                self.dataset_validation = CTCodeSummarizationOnlyASTDataset(
                    data_manager_validation, token_distances=token_distances,
                    max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                    num_sub_tokens_output=num_subtokens_output,
                    use_pointer_network=use_pointer_network, mask_all_tokens=mask_all_tokens)
            elif use_no_punctuation:
                self.dataset_validation = CTCodeSummarizationDatasetNoPunctuation(
                    data_manager_validation, token_distances=token_distances,
                    max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                    num_sub_tokens_output=num_subtokens_output,
                    use_pointer_network=use_pointer_network)
            else:
                self.dataset_validation = CTCodeSummarizationDataset(
                    data_manager_validation, token_distances=token_distances,
                    max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                    num_sub_tokens_output=num_subtokens_output,
                    use_pointer_network=use_pointer_network)

        self.dataset_validation_creator = \
            lambda infinite_loading: self._create_validation_dataset(DATA_PATH_STAGE_2, language,
                                                                     use_only_ast,
                                                                     use_no_punctuation,
                                                                     token_distances,
                                                                     num_sub_tokens,
                                                                     num_subtokens_output,
                                                                     infinite_loading,
                                                                     use_pointer_network,
                                                                     filter_language,
                                                                     dataset_imbalance,
                                                                     mask_all_tokens)

    def _create_validation_dataset(self, data_location, language, use_only_ast,
                                   use_no_punctuation, token_distances, num_sub_tokens,
                                   num_subtokens_output, infinite_loading, use_pointer_network,
                                   filter_language, dataset_imbalance, mask_all_tokens):
        data_manager_validation = CTBufferedDataManager(data_location, language,
                                                        partition="valid", shuffle=True,
                                                        infinite_loading=infinite_loading,
                                                        filter_language=filter_language,
                                                        dataset_imbalance=dataset_imbalance)
        if use_only_ast:
            dataset_validation = CTCodeSummarizationOnlyASTDataset(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network, mask_all_tokens=mask_all_tokens)
        elif use_no_punctuation:
            dataset_validation = CTCodeSummarizationDatasetNoPunctuation(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)
        else:
            dataset_validation = CTCodeSummarizationDataset(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_sub_tokens=num_sub_tokens,
                num_sub_tokens_output=num_subtokens_output,
                use_pointer_network=use_pointer_network)
        return dataset_validation
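
# Usage sketch (hypothetical, not in the original file): dataset_validation_creator defers
# constructing the validation dataset until it is needed, so each full validation pass can spin
# up a fresh, finite data manager. ExperimentSetup.train below uses it as follows:
#
#     dataset = self.dataset_validation_creator(False)  # finite loading: one pass over "valid"
#     dataloader = islice(dataset.to_dataloader(), int(max_validation_samples / batch_size))
#     score = self.evaluate(dataloader, step, 'valid_full')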
class ExperimentSetup:

    def __init__(self):
        self._init_config()
        self._init_data_transforms()
        self._init_data()
        self._init_transfer_learning()
        self._init_model()
        self._init_optimizer()

    @ex.capture
    def _init_config(self, _config):
        self.config = _config

    @ex.capture(prefix="data_transforms")
    def _init_data_transforms(self, max_distance_mask, relative_distances, distance_binning):
        self.max_distance_mask = None if max_distance_mask is None \
            else MaxDistanceMaskTransform(max_distance_mask)
        self.relative_distances = [] if relative_distances is None else relative_distances
        if distance_binning['type'] == 'exponential':
            trans_func = ExponentialBinning(distance_binning['growth_factor'])
        else:
            trans_func = EqualBinning()
        self.distance_binning = {
            'n_fixed_bins': distance_binning['n_fixed_bins'],
            'trans_func': trans_func
        }

    @ex.capture(prefix="data_setup")
    def _init_data(self, language, num_predict, use_validation=False, mini_dataset=False,
                   use_no_punctuation=False, use_pointer_network=False, sort_by_length=False,
                   shuffle=True, chunk_size=None, filter_language=None, dataset_imbalance=None,
                   num_sub_tokens=NUM_SUB_TOKENS):
        self.data_manager = CTBufferedDataManager(DATA_PATH_STAGE_2, language, shuffle=shuffle,
                                                  infinite_loading=True, mini_dataset=mini_dataset,
                                                  size_load_buffer=10000,
                                                  sort_by_length=sort_by_length,
                                                  chunk_size=chunk_size,
                                                  filter_language=filter_language,
                                                  dataset_imbalance=dataset_imbalance)
        self.word_vocab, self.token_type_vocab, self.node_type_vocab = \
            self.data_manager.load_vocabularies()

        token_distances = None
        if TokenDistancesTransform.name in self.relative_distances:
            num_bins = self.data_manager.load_config()['binning']['num_bins']
            token_distances = TokenDistancesTransform(
                DistanceBinning(num_bins, self.distance_binning['n_fixed_bins'],
                                self.distance_binning['trans_func']))

        self.num_predict = num_predict
        self.use_pointer_network = use_pointer_network
        # For language modeling we always operate only on the method body vocabulary
        self.use_separate_vocab = False

        if use_no_punctuation:
            self.dataset_train = CTLanguageModelingDatasetNoPunctuation(
                self.data_manager, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)
        else:
            self.dataset_train = CTLanguageModelingDataset(
                self.data_manager, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)

        self.use_validation = use_validation
        if self.use_validation:
            data_manager_validation = CTBufferedDataManager(DATA_PATH_STAGE_2, language,
                                                            partition="valid", shuffle=True,
                                                            infinite_loading=True,
                                                            mini_dataset=mini_dataset,
                                                            size_load_buffer=10000,
                                                            filter_language=filter_language,
                                                            dataset_imbalance=dataset_imbalance)
            if use_no_punctuation:
                self.dataset_validation = CTLanguageModelingDatasetNoPunctuation(
                    data_manager_validation, token_distances=token_distances,
                    max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                    use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)
            else:
                self.dataset_validation = CTLanguageModelingDataset(
                    data_manager_validation, token_distances=token_distances,
                    max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                    use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)

        self.dataset_validation_creator = \
            lambda infinite_loading: self._create_validation_dataset(DATA_PATH_STAGE_2, language,
                                                                     use_no_punctuation,
                                                                     token_distances,
                                                                     infinite_loading, num_predict,
                                                                     use_pointer_network,
                                                                     filter_language,
                                                                     dataset_imbalance,
                                                                     num_sub_tokens)

    def _create_validation_dataset(self, data_location, language, use_no_punctuation,
                                   token_distances, infinite_loading, num_predict,
                                   use_pointer_network, filter_language, dataset_imbalance,
                                   num_sub_tokens):
        data_manager_validation = CTBufferedDataManager(data_location, language,
                                                        partition="valid", shuffle=True,
                                                        infinite_loading=infinite_loading,
                                                        size_load_buffer=10000,
                                                        filter_language=filter_language,
                                                        dataset_imbalance=dataset_imbalance)
        if use_no_punctuation:
            return CTLanguageModelingDatasetNoPunctuation(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)
        else:
            return CTLanguageModelingDataset(
                data_manager_validation, token_distances=token_distances,
                max_distance_mask=self.max_distance_mask, num_labels_per_sample=num_predict,
                use_pointer_network=use_pointer_network, num_sub_tokens=num_sub_tokens)

    @ex.capture(prefix="transfer_learning")
    def _init_transfer_learning(self, use_pretrained_model=False, model_type=None, run_id=None,
                                snapshot_iteration=None, cpu=False, freeze_encoder_layers=None):
        assert not use_pretrained_model or (run_id is not None and snapshot_iteration is not None
                                            and model_type is not None), \
            "model_type, run_id and snapshot_iteration have to be provided if " \
            "use_pretrained_model is set"
        self.use_pretrained_model = use_pretrained_model
        if use_pretrained_model:
            print(f"Using Transfer Learning. Loading snapshot snapshot-{snapshot_iteration} "
                  f"from run {run_id} in collection {model_type}")
            if model_type == 'ct_code_summarization':
                model_manager = CodeTransformerModelManager()
                pretrained_model = model_manager.load_model(run_id, snapshot_iteration,
                                                            gpu=not cpu)
                self.pretrained_model = pretrained_model
            elif model_type == 'ct_lm':
                model_manager = CodeTransformerLMModelManager()
                pretrained_model = model_manager.load_model(run_id, snapshot_iteration,
                                                            gpu=not cpu)
                self.pretrained_model = pretrained_model
            else:
                model_manager = ModelManager(MODELS_SAVE_PATH, model_type)
                self.pretrained_model_params = model_manager.load_parameters(run_id,
                                                                             snapshot_iteration,
                                                                             gpu=not cpu)
                encoder_config = model_manager.load_config(run_id)['model']['transformer_lm_encoder']
                self.pretrained_transformer_encoder_config = \
                    TransformerLMEncoderConfig(**encoder_config)
            if freeze_encoder_layers is not None:
                self.freeze_encoder_layers = freeze_encoder_layers

    def generate_transformer_lm_encoder_config(self, transformer_lm_encoder: dict) \
            -> TransformerLMEncoderConfig:
        config = TransformerLMEncoderConfig(**transformer_lm_encoder)
        if self.use_pretrained_model:
            loaded_config = self.pretrained_transformer_encoder_config
            if not config == self.pretrained_transformer_encoder_config:
                print(f"Pretrained configuration differs from given configuration. Pretrained: "
                      f"{self.pretrained_transformer_encoder_config}, Given: {config}. "
                      f"Try merging...")
                loaded_config.input_nonlinearity = config.input_nonlinearity
                loaded_config.transformer['encoder_layer']['dropout'] = \
                    config.transformer['encoder_layer']['dropout']
                loaded_config.transformer['encoder_layer']['activation'] = \
                    config.transformer['encoder_layer']['activation']
            config = loaded_config

        transformer_config = dict(config.transformer)
        if hasattr(self, "word_vocab"):
            config.vocab_size = len(self.word_vocab)
        if hasattr(self, "token_type_vocab"):
            if hasattr(self, "use_only_ast") and self.use_only_ast:
                config.num_token_types = None
            else:
                config.num_token_types = len(self.token_type_vocab)
        if hasattr(self, "node_type_vocab"):
            config.num_node_types = len(self.node_type_vocab)
        if hasattr(self, "relative_distances"):
            encoder_layer_config = dict(transformer_config['encoder_layer'])
            encoder_layer_config['num_relative_distances'] = len(self.relative_distances)
            transformer_config['encoder_layer'] = encoder_layer_config
        if hasattr(self, "num_sub_tokens"):
            config.subtokens_per_token = self.num_sub_tokens
        if hasattr(self, 'num_languages'):
            config.num_languages = self.num_languages
        config.transformer = transformer_config
        return config

    @abstractmethod
    def _init_model(self, *args, **kwargs):
        self.model_lm = None
        self.with_cuda = True
        self.model_manager = None

    @ex.capture(prefix="optimizer")
    def _init_optimizer(self, learning_rate, reg_scale, scheduler=None, scheduler_params=None,
                        optimizer="Adam"):
        if optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model_lm.parameters(), lr=learning_rate,
                                        weight_decay=reg_scale)
        elif optimizer == 'Momentum':
            self.optimizer = optim.SGD(self.model_lm.parameters(), lr=learning_rate,
                                       weight_decay=reg_scale, momentum=0.95, nesterov=True)

        self.scheduler = None
        if scheduler == 'OneCycleLR':
            self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, **scheduler_params)
        elif scheduler == 'MultiStepLR':
            self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, **scheduler_params)

    def _init_metrics(self, metrics):
        self.metrics = dict()
        pad_id = self.word_vocab[PAD_TOKEN]
        unk_id = self.word_vocab[UNKNOWN_TOKEN]
        for metric in metrics:
            if metric == 'top1_accuracy':
                self.metrics[metric] = top1_accuracy
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: top1_accuracy(
                    logits, labels, unk_id=unk_id, pad_id=pad_id)
            elif metric == 'top5_accuracy':
                self.metrics[metric] = lambda logits, labels: topk_accuracy(5, logits, labels)
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: topk_accuracy(
                    5, logits, labels, unk_id=unk_id, pad_id=pad_id)
            elif metric == 'precision':
                self.metrics[metric] = lambda logits, labels: precision(logits, labels,
                                                                        pad_id=pad_id)
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: precision(
                    logits, labels, pad_id=pad_id, unk_id=unk_id)
            elif metric == 'recall':
                self.metrics[metric] = lambda logits, labels: recall(logits, labels,
                                                                     pad_id=pad_id)
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: recall(
                    logits, labels, pad_id=pad_id, unk_id=unk_id)
            elif metric == 'f1_score':
                self.metrics[metric] = lambda logits, labels: f1_score(logits, labels,
                                                                       pad_id=pad_id)
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: f1_score(
                    logits, labels, pad_id=pad_id, unk_id=unk_id)
            elif metric == 'non_trivial_accuracy':
                self.metrics[metric] = lambda logits, labels: non_trivial_words_accuracy(
                    logits, labels, pad_id)
                self.metrics[f"{metric}_no_unk"] = lambda logits, labels: non_trivial_words_accuracy(
                    logits, labels, pad_id, unk_id=unk_id)
            elif metric == 'micro_f1_score':
                self.metrics[metric] = lambda logits, labels: micro_f1_score(
                    logits, labels, pad_id=pad_id, unk_id=unk_id)
            elif metric == 'rouge_2':
                self.metrics[metric] = lambda logits, labels: rouge_2(logits, labels,
                                                                      pad_id=pad_id)
            elif metric == 'rouge_l':
                self.metrics[metric] = lambda logits, labels: rouge_l(logits, labels,
                                                                      pad_id=pad_id)

    @ex.capture(prefix="training")
    def train(self, batch_size, simulated_batch_size, random_seed, metrics, validate_every=None,
              persistent_snapshot_every=None, simulated_batch_size_valid=None,
              early_stopping_patience=10, max_validation_samples=10000,
              accumulate_tokens_batch=False):
        if self.with_cuda:
            self.model_lm = self.model_lm.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"

        run_id = self.model_manager.generate_run_name()
        self.logger = ExperimentLogger(
            "experiment",
            TensorboardLogger(f"{LOGS_PATH}/{self.model_manager.model_type}/{run_id}"))
        self.logger.info("===============================================")
        self.logger.info(f"Starting run {run_id}")
        self.logger.info("===============================================")
        self.model_manager.save_config(run_id, self.config)
        early_stopping = EarlyStopping(self.model_manager, run_id, early_stopping_patience)

        num_params = sum([len(params.view(-1)) for params in self.model_lm.parameters()])
        self.logger.info(f"Start training model with {num_params} parameters")
        self.logger.info(f"Model setup: {self.model_lm}")

        self._init_metrics(metrics)
        torch.manual_seed(random_seed)
        random.seed(random_seed)

        # Simulated batches
        simulated_batch_size = batch_size if simulated_batch_size is None else simulated_batch_size
        assert simulated_batch_size % batch_size == 0, \
            "simulated_batch_size must be a multiple of batch_size"
        num_simulated_batches = simulated_batch_size // batch_size

        # Main train loop
        train_step = 0
        dataloader = DataLoader(self.dataset_train, batch_size=batch_size,
                                collate_fn=self.dataset_train.collate_fn)
        if self.use_validation:
            if simulated_batch_size_valid is None:
                simulated_batch_size_valid = simulated_batch_size
            num_simulated_batches_valid = simulated_batch_size_valid // batch_size
            dataloader_validation = iter(DataLoader(self.dataset_validation,
                                                    batch_size=batch_size,
                                                    collate_fn=self.dataset_validation.collate_fn))

        n_tokens_accumulate_batch = None
        if accumulate_tokens_batch:
            n_tokens_accumulate_batch = 0
        epoch = 1
        progress_bar = tqdm(total=int(self.data_manager.approximate_total_samples() / batch_size))
        progress_bar.set_description(f"Epoch {epoch}")

        # Ensure graceful shutdown when training is interrupted
        signal.signal(signal.SIGINT, self._handle_shutdown)

        with Timing() as t:
            for it, batch in enumerate(dataloader):
                self.logger.log_time(t.measure() / batch_size, "dataloader_seconds/sample",
                                     train_step * simulated_batch_size
                                     + (it % num_simulated_batches) * batch_size)

                # Calculate gradients
                batch = batch_filter_distances(batch, self.relative_distances)
                model_out = self._train_step(batch, num_simulated_batches)

                self.logger.log_time(t.measure() / batch_size, "model_seconds/sample",
                                     train_step * simulated_batch_size
                                     + (it % num_simulated_batches) * batch_size)

                # Log actual predicted words and labels
                self.logger.log_text(
                    "input/train",
                    str([[self.word_vocab.reverse_lookup(st.item())
                          for st in token
                          if st.item() != self.word_vocab[PAD_TOKEN]
                          and st.item() != self.word_vocab[EOS_TOKEN]]
                         for token in batch.tokens[0]]))
                self.logger.log_text("predicted words/train",
                                     str(self._decode_predicted_words(model_out, batch)))
                self.logger.log_text("labels/train", str(self._decode_labels(batch)))

                # Calculate metrics
                evaluation = self._evaluate_predictions(model_out.logits, batch.labels,
                                                        loss=model_out.loss)
                self.logger.log_sub_batch_metrics(evaluation)

                if accumulate_tokens_batch:
                    n_tokens_accumulate_batch += batch.sequence_lengths.sum().item()

                # Gradient accumulation: only update gradients every num_simulated_batches steps
                if not accumulate_tokens_batch and it % num_simulated_batches == (num_simulated_batches - 1) \
                        or accumulate_tokens_batch and n_tokens_accumulate_batch > simulated_batch_size:
                    if accumulate_tokens_batch:
                        n_tokens_accumulate_batch = 0
                    train_step += 1

                    total_norm = 0
                    for p in self.model_lm.parameters():
                        if p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                    total_norm = total_norm ** (1. / 2)
                    self.logger.log_metrics({'gradient_norm': total_norm},
                                            train_step * simulated_batch_size)

                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.scheduler:
                        if not hasattr(self.scheduler, "total_steps") \
                                or train_step < self.scheduler.total_steps - 1:
                            self.scheduler.step()
                        self.logger.log_metrics({'lr': self.scheduler.get_lr()[0]},
                                                train_step * simulated_batch_size)

                    # Send train metrics to observers
                    self.logger.flush_batch_metrics(train_step * simulated_batch_size)

                    # Evaluate on validation set
                    if self.use_validation and validate_every and train_step % validate_every == 0:
                        t.measure()
                        self.model_lm.eval()
                        with torch.no_grad():
                            for validation_batch in islice(dataloader_validation,
                                                           num_simulated_batches_valid):
                                validation_batch = batch_filter_distances(validation_batch,
                                                                          self.relative_distances)
                                validation_batch = batch_to_device(validation_batch, self.device)
                                output = self.model_lm.forward_batch(validation_batch).cpu()
                                validation_batch = batch_to_device(validation_batch, "cpu")
                                evaluation = self._evaluate_predictions(output.logits,
                                                                        validation_batch.labels,
                                                                        loss=output.loss,
                                                                        partition='valid')
                                self.logger.log_sub_batch_metrics(evaluation)
                                self.logger.log_text(
                                    "predicted words/validation",
                                    str(self._decode_predicted_words(output, validation_batch)))
                                self.logger.log_text("labels/validation",
                                                     str(self._decode_labels(validation_batch)))
                        self.model_lm.train()
                        self.logger.flush_batch_metrics(step=train_step * simulated_batch_size)
                        self.logger.log_time(t.measure() / simulated_batch_size_valid,
                                             "valid_seconds/sample",
                                             train_step * simulated_batch_size)

                if persistent_snapshot_every and (it + 1) % persistent_snapshot_every == 0:
                    snapshot_iteration = it + 1
                    self.logger.info(f"Storing model params into snapshot-{snapshot_iteration}")
                    self.model_manager.save_snapshot(run_id, self.model_lm.state_dict(),
                                                     snapshot_iteration)
                    dataset = self.dataset_validation_creator(False)
                    score = self.evaluate(islice(dataset.to_dataloader(),
                                                 int(max_validation_samples / batch_size)),
                                          train_step * simulated_batch_size, 'valid_full')
                    if "micro_f1_score/valid_full" in self.logger.sub_batch_metrics:
                        score_name = 'micro-F1'
                    else:
                        score_name = 'F1'
                    self.logger.info(f"Full evaluation yielded {score} {score_name}")
                    if not early_stopping.evaluate(score, snapshot_iteration):
                        self.logger.info(f"Last {early_stopping_patience} evaluations did not "
                                         f"improve performance. Stopping run")
                        break

                progress_bar.update()
                if progress_bar.n >= progress_bar.total:
                    progress_bar = tqdm(
                        total=int(self.data_manager.approximate_total_samples() / batch_size))
                    epoch += 1
                    progress_bar.set_description(f"Epoch {epoch}")
                t.measure()

        self._handle_shutdown()

    def _train_step(self, batch, num_simulated_batches):
        batch = batch_to_device(batch, self.device)
        output_gpu = self.model_lm.forward_batch(batch)
        # Gradient accumulation: every batch contributes only a part of the total gradient
        (output_gpu.loss / num_simulated_batches).backward()
        output_cpu = output_gpu.cpu()
        del output_gpu
        del batch
        return output_cpu

    def _evaluate_predictions(self, logits, labels, loss=None, partition='train'):
        evaluation = dict()
        for metric_name, metric_fn in self.metrics.items():
            evaluation[f"{metric_name}/{partition}"] = metric_fn(logits, labels)
        if loss:
            evaluation[f"loss/{partition}"] = loss.item()
        return evaluation

    def evaluate(self, dataset, step, partition='valid'):
        # Evaluate on validation set
        self.model_lm.eval()
        predictions = []
        labels = []
        with torch.no_grad():
            for validation_batch in dataset:
                validation_batch = batch_filter_distances(validation_batch,
                                                          self.relative_distances)
                validation_batch = batch_to_device(validation_batch, self.device)
                output = self.model_lm.forward_batch(validation_batch).cpu()
                validation_batch = batch_to_device(validation_batch, "cpu")
                predictions.extend(output.logits.argmax(-1))
                labels.extend(validation_batch.labels)
                evaluation = self._evaluate_predictions(output.logits, validation_batch.labels,
                                                        loss=output.loss, partition=partition)
                self.logger.log_sub_batch_metrics(evaluation)
                self.logger.log_text(
                    "predicted words/validation",
                    str(self._decode_predicted_words(output, validation_batch)))
                self.logger.log_text("labels/validation",
                                     str(self._decode_labels(validation_batch)))
        self.model_lm.train()
        if f"micro_f1_score/{partition}" in self.logger.sub_batch_metrics:
            score = mean(self.logger.sub_batch_metrics[f"micro_f1_score/{partition}"])
        else:
            score = mean(self.logger.sub_batch_metrics[f"f1_score/{partition}"])
        self.logger.flush_batch_metrics(step=step)
        return score

    def _decode_predicted_words(self, model_out, batch):
        method_name_vocab = self.method_name_vocab if self.use_separate_vocab else self.word_vocab
        if hasattr(self, 'use_pointer_network') and self.use_pointer_network:
            extended_vocab_reverse = {idx: word
                                      for word, idx in batch.extended_vocabulary[0].items()}
            predicted_sub_tokens = ((predicted_sub_token.argmax().item(),
                                     predicted_sub_token.max().item())
                                    for predicted_sub_token in model_out.logits[0][0])
            return [(extended_vocab_reverse[st] if st in extended_vocab_reverse
                     else method_name_vocab.reverse_lookup(st), f"{value:0.2f}")
                    for st, value in predicted_sub_tokens]
        else:
            return [(method_name_vocab.reverse_lookup(predicted_sub_token.argmax().item()),
                     f"{predicted_sub_token.max().item():0.2f}")
                    for predicted_sub_token in model_out.logits[0][0]]

    def _decode_labels(self, batch):
        method_name_vocab = self.method_name_vocab if self.use_separate_vocab else self.word_vocab
        if hasattr(self, 'use_pointer_network') and self.use_pointer_network:
            extended_vocab_reverse = {idx: word
                                      for word, idx in batch.extended_vocabulary[0].items()}
            label_tokens = (sub_token_label.item() for sub_token_label in batch.labels[0][0])
            return [extended_vocab_reverse[lt] if lt in extended_vocab_reverse
                    else method_name_vocab.reverse_lookup(lt)
                    for lt in label_tokens]
        else:
            return [method_name_vocab.reverse_lookup(sub_token_label.item())
                    for sub_token_label in batch.labels[0][0]]

    def get_dataloader(self, split: str, batch_size: int):
        assert split == 'train' or split == 'validation'
        if split == 'train':
            ds = self.dataset_train
        elif split == 'validation':
            ds = self.dataset_validation
        dl = DataLoader(ds, batch_size=batch_size, num_workers=0, collate_fn=ds.collate_fn)
        dl = DataLoaderWrapper(dl)
        return BufferedDataManager(dl)

    def _handle_shutdown(self, sig=None, frame=None):
        self.dataset_train.data_manager.shutdown()
        self.dataset_validation.data_manager.shutdown()
        sys.exit(0)
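
# Self-contained sketch of the gradient accumulation scheme used in train/_train_step above
# (model, optimizer and data here are stand-ins): each of num_simulated_batches micro-batches
# contributes loss / num_simulated_batches to the gradient, and the optimizer steps once per
# simulated batch of batch_size * num_simulated_batches samples.
if __name__ == '__main__':
    import torch
    from torch import nn, optim

    model = nn.Linear(8, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    num_simulated_batches = 4  # simulated_batch_size // batch_size

    for it in range(8):
        x, y = torch.randn(16, 8), torch.randint(2, (16,))
        loss = nn.functional.cross_entropy(model(x), y)
        (loss / num_simulated_batches).backward()  # accumulate a fraction of the total gradient
        if it % num_simulated_batches == num_simulated_batches - 1:
            optimizer.step()       # one update per simulated batch
            optimizer.zero_grad()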