def __update_label_info(unlabeled_bundle, lbls_list, soft_lbl_list, do_softmax, temperature):
    """Overwrite the labels stored in `unlabeled_bundle` with externally produced ones.

    Only slot 0 of `unlabeled_bundle.input_y` and `unlabeled_bundle.input_y_row` is written,
    one entry per element of `unlabeled_bundle.tws`.

    * One entry in `lbls_list`: the hard label is copied from `lbls_list[0]` and the row
      from `soft_lbl_list[0]`.
    * Two entries in `lbls_list`: the two rows of `soft_lbl_list` are averaged, and the
      hard label becomes the index of the largest averaged value (`lbls_list` itself is
      not read in this case).

    When `do_softmax` is true every row is divided by `temperature` and passed through a
    softmax before it is stored/averaged; otherwise rows are used unchanged.

    Raises:
        NotImplementedError: when `lbls_list` holds neither one nor two entries.
    """

    def soften(row):
        # `dim` is passed explicitly: calling F.softmax() without it is deprecated and
        # warns on every call.
        # NOTE(review): assumes each row is a flat score vector, for which dim=-1 is the
        # same axis the implicit behaviour picked -- confirm against the callers.
        return F.softmax(torch.tensor(row) / temperature, dim=-1).numpy()

    source_count = len(lbls_list)
    if source_count == 1:
        for tw_ind, _ in enumerate(unlabeled_bundle.tws):
            unlabeled_bundle.input_y[0][tw_ind] = lbls_list[0][tw_ind]
            if do_softmax:
                unlabeled_bundle.input_y_row[0][tw_ind] = soften(soft_lbl_list[0][tw_ind]).tolist()
            else:
                unlabeled_bundle.input_y_row[0][tw_ind] = soft_lbl_list[0][tw_ind]
    elif source_count == 2:
        for tw_ind, _ in enumerate(unlabeled_bundle.tws):
            if do_softmax:
                y_row = soften(soft_lbl_list[0][tw_ind]) + soften(soft_lbl_list[1][tw_ind])
            else:
                y_row = np.array(soft_lbl_list[0][tw_ind]) + np.array(soft_lbl_list[1][tw_ind])
            # average of the two sources
            y_row = (y_row / 2).tolist()
            unlabeled_bundle.input_y_row[0][tw_ind] = y_row
            unlabeled_bundle.input_y[0][tw_ind] = y_row.index(max(y_row))
            ELib.PASS()
    else:
        # NotImplementedError is still an Exception, so existing handlers keep working
        raise NotImplementedError('not implemented function!')
    ELib.PASS()
def train(self, train_bundle_list, valid_bundle_list=None, weighted_instance_loss=False,
          input_mode=EInputListMode.sequential, setup_learning_tools=True,
          extra_scheduled_trainset_size=0, extra_scheduled_epochs=0,
          customized_optimizer_params=None, report_number_of_intervals=20,
          switch_on_train_mode=True, train_shuffle=True, train_drop_last=True,
          balance_batch_mode_list=None, minimum_train_loss=None):
    """Run the epoch loop: train, validate, report, checkpoint and stop early if asked.

    The classifier is moved to `self.config.device` for the run and back to the CPU at
    the end. After the call `self.early_stopped_epoch` holds the (1-based) best epoch if
    validation stopped the run, and `self.train_loss_early_stopped_epoch` holds the
    (0-based) epoch in which a task's train loss reached `minimum_train_loss`; both are
    -1 otherwise.

    Args:
        setup_learning_tools: when true the optimizer/scheduler are (re)created, which
            resets the learning rate and the scheduler of any previous `train()` call.
        switch_on_train_mode: when false the classifier is kept in `eval()` mode while
            the training batches are processed.
        minimum_train_loss: if given, training stops after the first epoch in which a
            non-empty train task has a loss at or below this value.

    Returns:
        `(train_tasks, valid_tasks)`, or `None` when there is exactly one train bundle
        and it holds fewer tweets than `self.config.batch_size` (nothing is trained).
    """
    if len(train_bundle_list) == 1 and len(train_bundle_list[0].tws) < self.config.batch_size:
        return
    ## init
    self.bert_classifier.to(self.config.device)
    self.bert_classifier.zero_grad()
    if setup_learning_tools:
        ## caveat: if you have called train() before this will reset the learning rate and the scheduler!
        self.__setup_learning_tools(train_bundle_list, weighted_instance_loss, input_mode,
                                    extra_scheduled_trainset_size, extra_scheduled_epochs,
                                    customized_optimizer_params)
    train_dt_list, train_tasks = self.get_datasets_and_tasks(train_bundle_list)
    valid_dt_list, valid_tasks = self.get_datasets_and_tasks(valid_bundle_list,
                                                            self.config.early_stopping_patience)
    ## main loop
    self.early_stopped_epoch = -1
    self.train_loss_early_stopped_epoch = -1
    # cur_ep is read again after the loop; -1 keeps that read defined when no epoch runs
    # and makes '(cur_ep + 1) % interval' zero, so no final checkpoint is written then
    cur_ep = -1
    for cur_ep in range(math.ceil(self.config.epoch_count)):
        self.current_train_epoch = cur_ep
        self.bert_classifier.epoch_index += 1  # to track the overall number inside the classifier
        ## train
        if switch_on_train_mode:
            self.bert_classifier.train()
        else:
            self.bert_classifier.eval()
        self.__train_one_epoch(train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                               report_number_of_intervals, train_shuffle, train_drop_last,
                               balance_batch_mode_list)
        ## validation
        with torch.no_grad():
            stopping_valid_task = self.__validate_one_epoch(valid_bundle_list, valid_dt_list,
                                                            valid_tasks, weighted_instance_loss)
        ## post process epoch
        self.__print_epoch_results(cur_ep + 1, self.config.epoch_count, train_tasks, valid_tasks)
        if self.config.take_train_checkpoints and \
                (cur_ep + 1) % self.config.train_checkpoint_interval == 0:
            print('saving checkpoint...')
            self.save(str(cur_ep + 1))
        if stopping_valid_task is not None:
            print('stopped early by \'' + stopping_valid_task[0] +
                  '\'. restored the model of epoch {}'.
                  format(stopping_valid_task[1].learning_state.best_index + 1))
            self.early_stopped_epoch = stopping_valid_task[1].learning_state.best_index + 1
            break
        if minimum_train_loss is not None:
            # stop as soon as any non-empty train task is at or below the requested loss
            if any(cur_state.size > 0 and cur_state.loss <= minimum_train_loss
                   for cur_state in train_tasks.values()):
                self.train_loss_early_stopped_epoch = cur_ep
                break
        gc.collect()
        ELib.PASS()
    ## save it if needed
    if self.config.take_train_checkpoints and \
            (cur_ep + 1) % self.config.train_checkpoint_interval != 0:
        print('saving checkpoint...')
        self.save(str(cur_ep + 1))
    self.bert_classifier.cpu()
    ELib.PASS()
    return train_tasks, valid_tasks
def __get_query_vec(self, tokens, queries):
    """Return a 0/1 list, one entry per token, marking query positions.

    Every query is tokenized with `self.tokenizer`; the index that `__find_sublist`
    reports for it inside `tokens` is set to 1. A negative index leaves the vector
    untouched for that query.
    """
    marks = [0 for _ in tokens]
    for query_text in queries:
        query_pieces = self.tokenizer.tokenize(query_text)
        position = self.__find_sublist(query_pieces, tokens)
        if position >= 0:
            marks[position] = 1
        ELib.PASS()
    ELib.PASS()
    return marks
def run(cmd, per_query, train_path, valid_path_nullable, test_path_nullable,
        unlabeled_path_nullable, model_path, model_path_2, lm_model_path, t_lbl_path_1,
        t_lbl_path_2, output_dir, device, device_2, seed, train_sample, unlabeled_sample):
    """Entry point: build the config and data bundles and launch the requested command.

    `output_dir` is created if it is missing. When `per_query` is true the pipeline is
    executed once per query found in the training file, otherwise once with no query
    filter. Only `EPretrainCMD.bert_reg` is handled here; any other command does nothing
    beyond the setup steps.
    """
    cmd = EPretrainCMD.get(cmd)
    # mapping the dataset labels to negative and positive
    lc = ELblConf(0, 1, [ELbl(0, EVar.LblNonEventHealth), ELbl(1, EVar.LblEventHealth)])
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(output_dir, exist_ok=True)
    queries = [None]
    if per_query:
        queries = ETweet.get_queries(ETweet.load(train_path, ELoadType.none))
    for cur_query in queries:
        if cur_query is not None:
            print('>>>>>>>> "' + cur_query + '" began')
        if cmd == EPretrainCMD.bert_reg:
            config = EBertConfig.get_config(cmd, EBertCLSType.none, model_path, model_path_2,
                                            lm_model_path, t_lbl_path_1, t_lbl_path_2,
                                            output_dir, 5, device, device_2, seed, cur_query,
                                            gradient_checkpointing=False,
                                            check_early_stopping=False)
            train_bundle, valid_bundle, test_bundle, unlabeled_bundle = EInputBundle.get_data(
                config.label_count, lc, train_path, valid_path_nullable, test_path_nullable,
                unlabeled_path_nullable, cur_query)
            # down-sample the labeled and the unlabeled data before training
            train_bundle = EPretrainProj.__stratified_sample_from_bundle(
                train_bundle, lc, config.seed, train_sample)
            EPretrainProj.__sample_from_unlabeled(unlabeled_bundle, unlabeled_sample,
                                                  config.seed)
            EPretrainProj.__run_twin_net(config, lc, cur_query, train_bundle, valid_bundle,
                                         test_bundle, unlabeled_bundle)
            ELib.PASS()
    ELib.PASS()
def __validate_one_epoch(self, valid_bundle_list, valid_dt_list, valid_tasks,
                         weighted_instance_loss):
    """Evaluate on the validation sets and run the early-stopping check.

    Nothing happens when `valid_bundle_list` is None. Otherwise the classifier is put in
    eval mode, every validation task is reset, all validation batches are pushed through
    the model, and each task's loss (divided by its size) and F1 are stored on the task.

    Returns:
        The `(task_name, task_state)` item whose learning state asked to stop, after the
        classifier has been replaced by that state's best model; None if nothing stopped.
    """
    stopper = None
    if valid_bundle_list is not None:
        self.bert_classifier.eval()
        for task_state in valid_tasks.values():
            task_state.reset()
        for dataset in valid_dt_list:
            batches = self.generate_batches([dataset], self.config, False, False,
                                            self.current_train_epoch,
                                            EInputListMode.sequential)
            for batch in batches:
                outcome = self.bert_classifier(batch, False)
                self.__process_loss(outcome, batch, valid_tasks, False,
                                    weighted_instance_loss)
                self.delete_batch_from_gpu(batch, EInputListMode.sequential)
                del batch, outcome
        for task_state in valid_tasks.values():
            task_state.loss /= task_state.size
            task_state.f1 = ELib.calculate_f1(task_state.lbl_true, task_state.lbl_pred)
        ################ checks early stopping only if the model does not have hooks
        ## deepcopy() cannot copy hooks! fix it later...
        hooks_free = len(self.bert_classifier.logs) == 0
        if self.config.check_early_stopping and hooks_free:
            for task_item in valid_tasks.items():
                learning_state = task_item[1].learning_state
                if learning_state.should_stop(task_item[1].loss, self.bert_classifier,
                                              self.config.device):
                    self.bert_classifier.cpu()
                    self.bert_classifier = task_item[1].learning_state.best_model
                    stopper = task_item
                    break
    return stopper
def __str__(self):
    """Serialize the tweet as one tab-separated line (surrounding whitespace stripped).

    Field order: id, label, padded user id immediately followed by the normalized time,
    reply/like/retweet counts, query field, text. The query field is the single query,
    or - when the query is the dummy token - every entry of `QueryList` prefixed by '|'
    (a lone '|' if the list is empty).
    """
    fields = [self.Tweetid, "\t", str(self.Label), "\t"]
    # the user id is padded to a fixed width, so no tab separates it from the time
    fields.append(self.Userid.ljust(ETweet.useridAlign, " "))
    fields.append(ELib.normalizeTime(self.Time))
    for counter in (self.ReplyCount, self.LikeCount, self.RetweetCount):
        fields.append("\t")
        fields.append(str(counter))
    fields.append("\t")
    if self.Query != ETweet.tokenDummyQuery:
        fields.append(self.Query)
    elif len(self.QueryList) > 0:
        fields.extend('|' + q for q in self.QueryList)
    else:
        fields.append('|')
    fields.append("\t")
    fields.append(self.Text)
    return "".join(fields).strip()
def close_logs(self):
    """Close every log in `self.logs` and remove all forward and backward hooks."""
    for log_writer in self.logs.values():
        log_writer.close()
    for forward_handle in self.hooksForward:
        forward_handle.remove()
    for backward_handle in self.hooksBackward:
        backward_handle.remove()
    ELib.PASS()
def populate_bundle(bundle, lc, size, seed):
    """Grow `bundle` in place to `size` tweets by re-sampling its own tweets.

    The extra tweets come from `ETweet.random_stratified_sample` over `bundle.tws` and
    are appended with `EInputBundle.append`. If the sampler returns more tweets than are
    missing, the surplus is dropped so the bundle never exceeds `size`.

    A bundle that already holds `size` tweets or more is left untouched.

    Raises:
        ValueError: if tweets are missing but the bundle is empty (nothing to sample from).
    """
    current_size = len(bundle.tws)
    size_needed = size - current_size
    if size_needed <= 0:
        # already large enough; a non-positive sampling ratio would be meaningless
        ELib.PASS()
        return
    if current_size == 0:
        raise ValueError('cannot populate an empty bundle')
    ratio_needed = float(size_needed) / current_size
    tws = ETweet.random_stratified_sample(bundle.tws, lc, ratio_needed, seed, True)
    if len(tws) > size_needed:
        # the sampler may overshoot by any amount, so cut to the exact number missing
        tws = tws[:size_needed]
    EInputBundle.append(bundle, bundle, tws)
    ELib.PASS()
def __print_inst(self, tokens, tokens_len, tokens_ids, query_vec):
    """Debug helper: print one line per position of a padded instance.

    The token list is wrapped in '[' / ']' and padded with '?' by
    `self.max_seq - tokens_len` entries; each printed line shows the position, the
    token, its id from `tokens_ids` and its flag from `query_vec`. A blank line follows.
    """
    padding = ['?'] * (self.max_seq - tokens_len)
    tokens = ['['] + tokens + [']'] + padding
    for position, token in enumerate(tokens):
        line = ''.join([
            str(position), ': ', token.ljust(7),
            ' [', str(tokens_ids[position]).ljust(5), ']',
            '(', str(query_vec[position]), ')',
        ])
        print(line)
    print()
    ELib.PASS()
def __init__(self, config):
    """Create the BERT encoder, a dropout layer and a linear output layer.

    The linear layer maps `bert_config.hidden_size` features to `config.label_count`
    outputs. `self._init_weights` is applied to the module tree afterwards.
    """
    super(EBertClassifierSimple, self).__init__()
    self.config = config
    encoder_config = self.config.bert_config
    # layers are created in this fixed order: encoder, dropout, output projection
    self.bert_layer = BertModel(encoder_config)
    self.last_dropout_layer = nn.Dropout(self.config.dropout_prob)
    self.last_layer = nn.Linear(
        encoder_config.hidden_size,
        self.config.label_count,
    )
    self.apply(self._init_weights)
    ELib.PASS()
def __hook_backward(self, module, grad_input, grad_output):
    """Backward hook: log gradient histograms for `module` on selected steps.

    Active only while `self.hook_activated` is set and `self.train_step` is a multiple
    of `self.hook_interval`. `grad_output[0]` is logged as 'out-gradients' and
    `grad_input[0]` as 'bias-gradients' into the log registered under the module's name.
    """
    if self.hook_activated and self.train_step % self.hook_interval == 0:
        log_key = self.__get_module_name(module)
        out_values = grad_output[0].to('cpu').detach().numpy()
        self.logs[log_key].add_histogram('out-gradients', out_values, self.train_step)
        in_values = grad_input[0].to('cpu').detach().numpy()
        self.logs[log_key].add_histogram('bias-gradients', in_values, self.train_step)
    ELib.PASS()
def convert_tags_to_tokens(dep_tags):
    """Convert dependency-tag lines into tokens, dropping delimiters.

    Tokens whose text is a delimiter (per `ELib.is_delimiter`) are discarded. A leading
    '#' is stripped from the text of the tokens that are kept. The tweet trees are then
    built over the kept tokens with `EFeat1Gram.build_tweet_trees`.

    Returns:
        The list of kept tokens.
    """
    c_result = EFeat1Gram.convert_lines_to_tokens(dep_tags)
    result = []
    for tok in c_result:
        if not ELib.is_delimiter(tok.Text):
            result.append(tok)
            # startswith() instead of Text[0]: an empty text must not raise IndexError
            # NOTE(review): the '#' stripping is applied to kept tokens only -- confirm
            if tok.Text.startswith('#'):
                tok.Text = tok.Text[1:]
    EFeat1Gram.build_tweet_trees(result)
    return result
def test(self, test_bundle, return_output_vecs=False, weighted_instance_loss=False, print_perf=True, title=None, report_number_of_intervals=20, return_output_vecs_get_details=True):
    """Label `test_bundle` with the classifier and compute its metrics.

    The classifier is moved to `self.config.device`, run in eval mode under
    `torch.no_grad()`, and moved back to the CPU before returning.

    Args:
        return_output_vecs: when true, `self.bert_classifier.output_vecs` is collected
            after every batch.
        return_output_vecs_get_details: when true (and the classifier exposes a non-None
            `output_vecs_detail`), those are collected as well.
        print_perf: print loss / F1 / precision / recall once labeling is done.
        title: optional prefix for the progress line.

    Returns:
        `None` if the bundle has more than one task; four empty lists if it has no
        tweets; otherwise `(lbl_pred, logits, [result_vecs, result_vecs_detail], perf)`
        where `perf` is the result of `ELib.calculate_metrics`.
    """
    if len(test_bundle.task_list) > 1:
        print('only one task is allowed for testing')
        return None
    if len(test_bundle.tws) == 0:
        return list(), list(), list(), list()
    if title is None:
        title = ''
    else:
        title += ' '
    self.bert_classifier.to(self.config.device)
    self.bert_classifier.zero_grad()
    self.bert_classifier.eval()
    self.setup_objective(weighted_instance_loss)
    test_dt = EBertDataset(test_bundle, self.tokenizer, self.config.max_seq)
    batches = self.generate_batches([test_dt], self.config, False, False, 0, EInputListMode.sequential)
    result_vecs = list()
    result_vecs_detail = list()
    # a single task state accumulates loss, predictions and logits over all batches
    tasks = {test_bundle.task_list[0] : ETaskState(test_bundle.task_list[0])}
    print(title + 'labeling ', end=' ', flush=True)
    with torch.no_grad():
        for ba_ind, cur_batch in enumerate(batches):
            outcome = self.bert_classifier(cur_batch, False)
            self.__process_loss(outcome, cur_batch, tasks, False, weighted_instance_loss)
            if return_output_vecs:
                result_vecs.extend(self.bert_classifier.output_vecs)
                if self.bert_classifier.output_vecs_detail is not None and return_output_vecs_get_details:
                    result_vecs_detail.extend(self.bert_classifier.output_vecs_detail)
            if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
                print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
            # release the batch before the next one is fetched
            self.delete_batch_from_gpu(cur_batch, EInputListMode.sequential)
            del cur_batch, outcome
    print()
    task_out = tasks[test_bundle.task_list[0]]
    # the accumulated loss is divided by the task size (empty bundles returned earlier)
    task_out.loss /= task_out.size
    perf = ELib.calculate_metrics(task_out.lbl_true, task_out.lbl_pred)
    if print_perf:
        print('Test Results L1> Loss: {:.3f} F1: {:.3f} Pre: {:.3f} Rec: {:.3f}'.format( task_out.loss, perf[0], perf[1], perf[2]) + '\t\t' + ELib.get_time())
    self.bert_classifier.cpu()
    return task_out.lbl_pred, task_out.logits, [result_vecs, result_vecs_detail], perf
def __init__(self, task_list, input_x, input_y, input_y_row, queries, input_weight,
             input_meta, tws):
    """Store the given task list, inputs, labels, queries, weights, meta and tweets."""
    # the arguments are kept as passed in; nothing is copied
    self.task_list, self.input_x = task_list, input_x
    self.input_y, self.input_y_row = input_y, input_y_row
    self.queries, self.input_weight = queries, input_weight
    self.input_meta, self.tws = input_meta, tws
    ELib.PASS()
def __load_pretrained_bert_module(module, config, loaded):
    """Load pretrained BERT weights into `module` unless it was already handled.

    A plain `BertModel` receives the state dict itself; an `EBertModelWrapper` receives
    it in its `bert_layer`. Handled modules are recorded in the `loaded` set, and a
    module found there only triggers an 'already loaded' message. Any other module type
    is reported as unknown.
    """
    if type(module) is BertModel:
        is_wrapper = False
    elif isinstance(module, EBertModelWrapper):
        is_wrapper = True
    else:
        print(colored('unknown bert module to load', 'red'))
        return
    if module in loaded:
        print('already loaded')
    else:
        receiver = module.bert_layer if is_wrapper else module
        pretrained_layer = EBertClassifier.__load_pretrained_bert_layer(config)
        receiver.load_state_dict(pretrained_layer.state_dict())
        loaded.add(module)
    ELib.PASS()
def load(filePath, load_type, tweet_file=True):
    """Read tweets from the UTF-8 text file `filePath`.

    Args:
        filePath: path of the file; tag files are looked up next to it with the
            suffixes '-tags' / '-tags-synthesized'.
        load_type: an `ELoadType` value. `pos_tagger` loads nothing; the `stored_*`
            values additionally attach tokens built from the stored tags to each tweet.
            Passing a bool is rejected.
        tweet_file: when true each line is parsed by `ETweet(line)`; otherwise each line
            is split on tabs and converted with `ETweet.__text_to_tweet_object`.

    Returns:
        The list of loaded tweets (empty for `ELoadType.pos_tagger`).

    Any exception while loading is printed and the process exits with status 1, as it
    does when `load_type` is a bool.
    """
    try:
        if isinstance(load_type, bool):
            print('change param to LoadType')
            sys.exit(1)
        result = []
        # the context manager closes the file even if reading fails
        with open(filePath, "r", encoding="utf-8") as file:
            lines = file.readlines()
        if load_type == ELoadType.pos_tagger:
            pass
        else:
            if tweet_file:
                for ind, line in enumerate(lines):
                    tw = ETweet(line)
                    result.append(tw)
                    if (ind + 1) % 1000000 == 0:
                        print((ind + 1), ' reading lines')
                # len(lines) instead of the loop index: the index does not exist when
                # the file is empty
                # NOTE(review): this message is taken to follow the reading loop -- confirm
                if len(lines) > 1000000 and ELoadType.none != load_type:
                    print('reading tokens')
                tags = None
                if load_type == ELoadType.stored_tags or \
                        load_type == ELoadType.stored_tags_all_DontUseThis:
                    tags = EFeat1Gram.read_dep_tags(filePath + "-tags")
                elif load_type == ELoadType.stored_injected_tags or \
                        load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                    tags = EFeat1Gram.read_dep_tags(filePath + "-tags-synthesized")
                for ind, tw in enumerate(result):
                    if load_type == ELoadType.stored_tags_all_DontUseThis or \
                            load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                        tw.ETokens = EFeat1Gram.convert_all_tags_to_tokens(tags[ind])
                    elif load_type == ELoadType.stored_tags or \
                            load_type == ELoadType.stored_injected_tags:
                        tw.ETokens = EFeat1Gram.convert_tags_to_tokens(tags[ind])
                    if (ind + 1) % 500000 == 0 and ELoadType.none != load_type:
                        print((ind + 1), ' constructing tokens')
            else:
                for ind, line in enumerate(lines):
                    tokens = line.strip().split('\t')
                    tw = ETweet.__text_to_tweet_object(tokens)
                    result.append(tw)
        ELib.PASS()
    except Exception as err:
        print(
            colored(
                'Error in loading: "{}"\n\n{}'.format(filePath, str(err)),
                'red'))
        sys.exit(1)
    return result
def save(self, prefix=''):
    """Write the classifier weights and copy the model's config/vocab next to them.

    The state dict goes to `<output_dir>/<prefix>pytorch_model.bin`. 'config.json' and
    'vocab.txt' are copied from `config.model_path` to `<output_dir>/<prefix><name>`,
    each copy being skipped when source and destination are the same path string.
    """
    out_dir = self.config.output_dir
    weights_path = os.path.join(out_dir, prefix + 'pytorch_model.bin')
    torch.save(self.bert_classifier.state_dict(), weights_path)
    for file_name in ('config.json', 'vocab.txt'):
        source_path = os.path.join(self.config.model_path, file_name)
        target_path = os.path.join(self.config.output_dir, prefix + file_name)
        if source_path != target_path:
            shutil.copyfile(source_path, target_path)
    ELib.PASS()
def __init__(self):
    """Initialise the bookkeeping shared by the classifiers (logs, counters, hooks)."""
    super(EClassifier, self).__init__()
    self.logs = {}
    self.train_state = 0
    # both counters start at -1
    self.epoch_index = -1
    self.train_step = -1
    # hook configuration stays unset (None) until assigned elsewhere
    self.hooked_modules = {}
    self.hook_interval = None
    self.hook_activated = None
    self.hooksForward = []
    self.hooksBackward = []
    ELib.PASS()
def __train_one_epoch(self, train_dt_list, train_tasks, input_mode, weighted_instance_loss, report_number_of_intervals, train_shuffle, train_drop_last, balance_batch_mode_list):
    """Run every training batch of one epoch and finalize the per-task metrics.

    All train tasks are reset first. For each batch the classifier is called and the
    loss processed (repeatedly while `self.delay_optimizer` is set), progress is
    printed, and the batch is released. When a `sync_obj` is present and its
    loss-calculation lock is held, this model bumps the shared counter, releases the
    lock and waits until the counter reaches `model_count`. After the last batch each
    non-empty task's loss is divided by its size and its F1 is computed.
    """
    batches = self.generate_batches(train_dt_list, self.config, train_shuffle, train_drop_last, self.current_train_epoch, input_mode, balance_batch_mode_list)
    [cur_task[1].reset() for cur_task in train_tasks.items()]
    for ba_ind, cur_batch in enumerate(batches):
        self.bert_classifier.train_step += 1 # to track the overall number inside the classifier
        # the forward/loss step is repeated for as long as delay_optimizer stays set
        while True:
            outcome = self.bert_classifier(cur_batch, False)
            self.__process_loss(outcome, cur_batch, train_tasks, True, weighted_instance_loss)
            if not self.delay_optimizer:
                break
        if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
            print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
        self.delete_batch_from_gpu(cur_batch, input_mode)
        del cur_batch, outcome
        ## in case there are multiple models and their losses are heavy (in terms of memory)
        ## you can call 'self.sync_obj.lock_loss_calculation.acquire()' in 'self.custom_train_loss_func()'
        ## This way the losses are calculated one by one and after that the models are re-synched
        if self.sync_obj is not None and self.sync_obj.lock_loss_calculation.locked():
            ## wait for the other models to arrive
            if self.sync_obj.sync_counter == self.sync_obj.model_count:
                self.sync_obj.reset()
            self.sync_obj.sync_counter += 1
            self.sync_obj.lock_loss_calculation.release()
            # poll until the shared counter shows that every model has arrived
            while self.sync_obj.sync_counter < self.sync_obj.model_count:
                self.sleep()
        # pprint(vars(self))
        # ELib.PASS()
    ## if there are multiple models avoid double printing the newline
    if self.sync_obj is None:
        print()
    elif self.model_id == 0:
        print()
    ## calculate the metric averages in the epoch
    for cur_task in train_tasks.items():
        if cur_task[1].size > 0:
            cur_task[1].loss /= cur_task[1].size
            cur_task[1].f1 = ELib.calculate_f1(cur_task[1].lbl_true, cur_task[1].lbl_pred)
    ELib.PASS()
def __sample_from_unlabeled(unlabeled_bundle, count, seed):
    """Keep a random sample of at most `count` tweets in `unlabeled_bundle` (in place).

    The global `random` generator is seeded with `seed`, tweets are drawn until `count`
    distinct ones are collected (capped at the bundle size), and every other tweet is
    removed through `EInputBundle.remove`.

    An empty bundle is left as it is; a non-positive `count` empties the bundle.
    """
    random.seed(seed)
    total = len(unlabeled_bundle.tws)
    count = min(count, total)
    if count <= 0:
        # nothing to keep: randint(0, -1) would raise on an empty bundle, and the draw
        # loop below would always keep one tweet before testing the target
        if total > 0:
            EInputBundle.remove(unlabeled_bundle, list(unlabeled_bundle.tws))
        ELib.PASS()
        return
    sample_set = set()
    while len(sample_set) < count:
        tw_ind = random.randint(0, total - 1)
        # a set ignores repeated draws, so the loop ends once `count` distinct tweets exist
        sample_set.add(unlabeled_bundle.tws[tw_ind])
    drop_list = [tw for tw in unlabeled_bundle.tws if tw not in sample_set]
    EInputBundle.remove(unlabeled_bundle, drop_list)
    ELib.PASS()
def __hook_forward(self, module, input, output):
    """Forward hook: log activation, weight and bias histograms on selected steps.

    Active only while `self.hook_activated` is set and `self.train_step` is a multiple
    of `self.hook_interval`. The histograms go to the log registered under the module's
    name, in the order activations, weights, bias.
    """
    if self.hook_activated and self.train_step % self.hook_interval == 0:
        log_key = self.__get_module_name(module)
        activation_values = output.to('cpu').detach().numpy()
        self.logs[log_key].add_histogram('activations', activation_values, self.train_step)
        weight_values = module.weight.data.to('cpu').detach().numpy()
        self.logs[log_key].add_histogram('weights', weight_values, self.train_step)
        bias_values = module.bias.data.to('cpu').detach().numpy()
        self.logs[log_key].add_histogram('bias', bias_values, self.train_step)
    ELib.PASS()
def __print_epoch_results(self, ep_no, all_ep, train_tasks, valid_tasks):
    """Print a one-line summary of the epoch: loss and F1 of every non-empty task.

    Train tasks come first, then validation tasks (if `valid_tasks` is not None); the
    line ends with a tab and the current time from `ELib.get_time()`.
    """
    segments = ['epoch: {}/{}> '.format(ep_no, all_ep)]
    for task_name, task_state in train_tasks.items():
        if task_state.size > 0:
            segments.append('|T: {}, tr-loss: {:.3f}, tr-f1: {:.3f} '.format(
                task_name, task_state.loss, task_state.f1))
    if valid_tasks is not None:
        for task_name, task_state in valid_tasks.items():
            if task_state.size > 0:
                segments.append('|T: {}, va-loss: {:.3f}, va-f1: {:.3f} '.format(
                    task_name, task_state.loss, task_state.f1))
    segments.append('\t' + ELib.get_time())
    print(''.join(segments))
def get_data(label_count, lc, train_path_nullable, valid_path_nullable, test_path_nullable,
             unlabeled_path_nullable, filter_query=None, remove_unlabeled_test_tweets=False,
             tokenize_by_etokens=False, pivot_query=None, max_set_length=0):
    """Load up to four tweet files and turn each one into an input bundle.

    Every path may be None, in which case the corresponding bundle is None. Files are
    read with `ETweet.load(..., ELoadType.none, tweet_file=False)` and converted with
    `EInputBundle.get_input_bundle` for the single task `EVar.DefaultTask`.

    Args:
        remove_unlabeled_test_tweets: when true, test tweets whose `Label` equals 0 are
            dropped before the test bundle is built.

    Returns:
        `(train_bundle, valid_bundle, test_bundle, unlabeled_bundle)`.
    """

    def load_tweets(path):
        return ETweet.load(path, ELoadType.none, tweet_file=False)

    def to_bundle(tws):
        return EInputBundle.get_input_bundle([EVar.DefaultTask], tws, lc, filter_query,
                                             tokenize_by_etokens, pivot_query, label_count,
                                             max_set_length)

    train_bundle = valid_bundle = test_bundle = unlabeled_bundle = None
    if train_path_nullable is not None:
        train_bundle = to_bundle(load_tweets(train_path_nullable))
    if valid_path_nullable is not None:
        valid_bundle = to_bundle(load_tweets(valid_path_nullable))
    if test_path_nullable is not None:
        tws_test = load_tweets(test_path_nullable)
        if remove_unlabeled_test_tweets:
            print('removing unlabeled test tweets ...')
            # one filtering pass instead of repeated 'del', which shifts the tail of the
            # list on every removal; the slice assignment keeps the same list object
            tws_test[:] = [cur_tw for cur_tw in tws_test if not cur_tw.Label == 0]
        test_bundle = to_bundle(tws_test)
    if unlabeled_path_nullable is not None:
        unlabeled_bundle = to_bundle(load_tweets(unlabeled_path_nullable))
    return train_bundle, valid_bundle, test_bundle, unlabeled_bundle
def __init__(self, seed, model_count, synchronized_bundle_indices=None):
    """Set up the locks, counters and per-model queues shared between the models."""
    self.seed = seed
    self.model_count = model_count
    # one lock per shared resource
    self.lock_dataset = threading.Lock()
    self.lock_batch = threading.Lock()
    self.lock_loss_calculation = threading.Lock()
    self.sync_bundle_indices = synchronized_bundle_indices
    self.sync_bundle_batches = {}
    self.sync_bundle_batches_sizes = []
    self.sync_counter = model_count  # this is needed for the first iteration
    # one queue and one meta slot (initially None) per model
    self.sync_list = [queue.Queue() for _ in range(model_count)]
    self.meta_list = [None for _ in range(model_count)]
    ELib.PASS()
def __align_tokens_in_tweet(pivot, pivot_pos, bert_tokens_rec, bert_tokens_rec_ind):
    """Find how many consecutive records, starting at the given index, spell `pivot`.

    The first elements of a growing window of `bert_tokens_rec` are concatenated and
    compared with `pivot`.

    Returns:
        `(span, phrase)` on an exact match; `(span, 'www')` when `pivot_pos` is 'U' and
        the window spells 'www'; `(1, pivot)` when the window spells '[UNK]'; `None`
        once the window would run past the end of the records.
    """
    start = bert_tokens_rec_ind
    span = 1
    while True:
        window = bert_tokens_rec[start:start + span]
        phrase = ''.join(record[0] for record in window)
        if pivot == phrase:
            return span, phrase
        if pivot_pos == 'U' and phrase == 'www':
            return span, 'www'
        if phrase == '[UNK]':
            return 1, pivot
        span += 1
        if start + span > len(bert_tokens_rec):
            return None
        ELib.PASS()
def get_input_bundle(task_list, tws, lc, filter_query, tokenize_by_etokens, pivot_query,
                     label_count, max_set_length=0):
    """Build an `EInputBundle` from a list of tweets.

    The tweets are optionally filtered by `filter_query` and truncated to
    `max_set_length` (when positive). For every remaining tweet the bundle receives the
    tokenized text, the mapped label (repeated for every task), a one-hot row of length
    `label_count`, the query list, a weight of 1.0 and a meta value of 0.0.
    """
    if filter_query is not None:
        tws = ETweet.filter_by_query(tws, filter_query)
    if max_set_length > 0:
        tws = tws[:max_set_length]
    task_count = len(task_list)
    result_x = []
    result_y = [[] for _ in range(task_count)]
    result_y_row = [[] for _ in range(task_count)]
    queries, weights, meta = [], [], []
    for cur_tw in tws:
        result_x.append(ELib.tokenize_tweet_text(cur_tw, True, tokenize_by_etokens,
                                                 pivot_query, cur_tw.QueryList))
        lbl = lc.get_correct_new_label(cur_tw.Label)
        for task_labels in result_y:
            task_labels.append(lbl)
        # every task gets its own one-hot row object
        for task_rows in result_y_row:
            task_rows.append([0] * label_count)
        for task_rows in result_y_row:
            task_rows[-1][lbl] = 1
        # query precedence: explicit pivot, then the tweet's query list, then its query
        if pivot_query is not None:
            tweet_queries = [pivot_query]
        elif cur_tw.Query == ETweet.tokenDummyQuery:
            tweet_queries = cur_tw.QueryList
        else:
            tweet_queries = [cur_tw.Query]
        queries.append(tweet_queries)
        weights.append(1.0)
        meta.append(0.0)
    return EInputBundle(task_list, result_x, result_y, result_y_row, queries, weights,
                        meta, tws)
def save_tweets_as_text_file(start_id, tws, file_path):
    """Write the tweets to `file_path` (UTF-8), one tab-separated line per tweet.

    Each line holds a running id (starting at `start_id`, zero-padded to 7 characters),
    the label, the query field and the stripped text. The query field is the tweet's
    query, or - for the dummy query - every `QueryList` entry prefixed by '|' (a lone
    '|' when the list is empty).
    """
    rows = []
    running_id = start_id
    for cur_tw in tws:
        if cur_tw.Query != ETweet.tokenDummyQuery:
            query = cur_tw.Query
        elif len(cur_tw.QueryList) > 0:
            query = ''.join('|' + q for q in cur_tw.QueryList)
        else:
            query = '|'
        rows.append('{}\t{}\t{}\t{}\n'.format(
            str(running_id).zfill(7), str(cur_tw.Label), query, cur_tw.Text.strip()))
        running_id += 1
    with io.open(file_path, 'w', encoding='utf-8') as ptr:
        ptr.write(''.join(rows))
    ELib.PASS()
def remove(bundle, to_remove_tws):
    """Delete the given tweets from every parallel list of `bundle`, in place.

    Tweets are matched by `Tweetid`; if the bundle holds the same id more than once,
    the id resolves to its last position. A tweet listed several times in
    `to_remove_tws` is removed only once.

    Raises:
        KeyError: if a tweet in `to_remove_tws` has an id that is not in the bundle.
    """
    id_dict = {cur_tw.Tweetid: cur_ind for cur_ind, cur_tw in enumerate(bundle.tws)}
    # a set collapses repeated requests for the same position: deleting one index twice
    # would take out whichever entry slid into that slot
    to_delete = sorted({id_dict[cur_tw.Tweetid] for cur_tw in to_remove_tws}, reverse=True)
    # highest index first, so the remaining indices stay valid while deleting
    for cur_ind in to_delete:
        del bundle.input_x[cur_ind]
        for y_ind in range(len(bundle.input_y)):
            del bundle.input_y[y_ind][cur_ind]
            del bundle.input_y_row[y_ind][cur_ind]
        del bundle.queries[cur_ind]
        del bundle.input_weight[cur_ind]
        del bundle.input_meta[cur_ind]
        del bundle.tws[cur_ind]
    ELib.PASS()
def remove_module_from_optimizer(self, module):
    """Take the parameter group that holds `module`'s parameters out of the optimizer.

    The group is located through the module's first parameter. Its 'lr',
    'weight_decay' and 'eps' are recorded, together with the module, in
    `self.removed_modules` before the group is deleted.

    The process exits with status 1 when `module` is an `nn.ModuleList` or when no
    group contains the module's first parameter (including a module without parameters).
    """
    if isinstance(module, nn.ModuleList):
        print(colored('>>> cannot handle ModuleList to delete from the optimizer! <<<', 'red'))
        sys.exit(1)
    params = list(module.parameters())
    found_ind = None
    if params:
        first_param = params[0]
        for group_ind, cur_group in enumerate(self.optimizer.param_groups):
            # identity test: 'in' on a list of tensors falls back to element-wise '==',
            # which can raise or match a different parameter with equal values
            if any(cur_param is first_param for cur_param in cur_group['params']):
                found_ind = group_ind
                break
    if found_ind is not None:
        found = self.optimizer.param_groups[found_ind]
        self.removed_modules.append([module, found['lr'], found['weight_decay'], found['eps']])
        # delete by position; list.remove() would compare whole groups (and their tensors)
        del self.optimizer.param_groups[found_ind]
    else:
        print(colored('>>> module was not found for deletion in the optimizer! <<< \n'
                      'perhaps you have passed "customized_params" to setup_optimizer()', 'red'))
        sys.exit(1)
    ELib.PASS()
def load_pretrained_bert_modules(modules, config):
    """Load pretrained BERT weights into every BERT module found in `modules`.

    `modules` is either an `OrderedDict` of named modules or a single module (then
    treated as one entry named 'reconfig'). Entries that are a `BertModel` or an
    `EBertModelWrapper` are loaded directly; for an `nn.ModuleList` each such child is
    loaded. The entry name (and child index) is printed before each load, and a shared
    set is passed along so that the loader can skip modules it has already handled.
    """

    def is_bert(candidate):
        return type(candidate) is BertModel or isinstance(candidate, EBertModelWrapper)

    if type(modules) is OrderedDict:
        mod_list = modules
    else:
        mod_list = OrderedDict([('reconfig', modules)])
    loaded = set()
    for entry_name, entry in mod_list.items():
        if is_bert(entry):
            print('{}: '.format(entry_name), end='', flush=True)
            EBertClassifier.__load_pretrained_bert_module(entry, config, loaded)
        elif type(entry) is nn.ModuleList:
            for child_ind, child in enumerate(entry):
                if is_bert(child):
                    print('{}[{}]: '.format(entry_name, child_ind), end='', flush=True)
                    EBertClassifier.__load_pretrained_bert_module(child, config, loaded)
    ELib.PASS()