def __update_label_info(unlabeled_bundle, lbls_list, soft_lbl_list, do_softmax, temperature):
    if len(lbls_list) == 1:
        # single teacher: copy its hard labels and (optionally temperature-softened) soft labels
        for tw_ind, _ in enumerate(unlabeled_bundle.tws):
            unlabeled_bundle.input_y[0][tw_ind] = lbls_list[0][tw_ind]
            if do_softmax:
                unlabeled_bundle.input_y_row[0][tw_ind] = \
                    F.softmax(torch.tensor(soft_lbl_list[0][tw_ind]) / temperature, dim=0).numpy().tolist()
            else:
                unlabeled_bundle.input_y_row[0][tw_ind] = soft_lbl_list[0][tw_ind]
    elif len(lbls_list) == 2:
        # two teachers: average their (optionally temperature-softened) distributions,
        # then take the argmax of the average as the hard label
        for tw_ind, _ in enumerate(unlabeled_bundle.tws):
            if do_softmax:
                y_row = F.softmax(torch.tensor(soft_lbl_list[0][tw_ind]) / temperature, dim=0).numpy() + \
                        F.softmax(torch.tensor(soft_lbl_list[1][tw_ind]) / temperature, dim=0).numpy()
            else:
                y_row = np.array(soft_lbl_list[0][tw_ind]) + np.array(soft_lbl_list[1][tw_ind])
            y_row = (y_row / 2).tolist()
            unlabeled_bundle.input_y_row[0][tw_ind] = y_row
            unlabeled_bundle.input_y[0][tw_ind] = y_row.index(max(y_row))
        ELib.PASS()
    else:
        raise Exception('not implemented function!')
    ELib.PASS()
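# Illustrative sketch (not part of the repo): the two-teacher branch above averages the
# teachers' temperature-softened distributions and uses the argmax as the hard label.
# The numbers below are made up; only torch and torch.nn.functional are assumed.
import torch
import torch.nn.functional as F

soft_a = torch.tensor([2.0, 0.5])          # soft scores from teacher 1 for one tweet
soft_b = torch.tensor([1.0, 1.5])          # soft scores from teacher 2 for the same tweet
temperature = 2.0
y_row = (F.softmax(soft_a / temperature, dim=0) +
         F.softmax(soft_b / temperature, dim=0)) / 2   # averaged distribution (input_y_row)
label = int(torch.argmax(y_row))                       # hard label (input_y)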
def train(self, train_bundle_list, valid_bundle_list=None, weighted_instance_loss=False,
          input_mode=EInputListMode.sequential, setup_learning_tools=True,
          extra_scheduled_trainset_size=0, extra_scheduled_epochs=0,
          customized_optimizer_params=None, report_number_of_intervals=20,
          switch_on_train_mode=True, train_shuffle=True, train_drop_last=True,
          balance_batch_mode_list=None, minimum_train_loss=None):
    if len(train_bundle_list) == 1 and len(train_bundle_list[0].tws) < self.config.batch_size:
        return
    ## init
    self.bert_classifier.to(self.config.device)
    self.bert_classifier.zero_grad()
    if setup_learning_tools:
        ## caveat: if you have called train() before, this will reset the learning rate and the scheduler!
        self.__setup_learning_tools(train_bundle_list, weighted_instance_loss, input_mode,
                                    extra_scheduled_trainset_size, extra_scheduled_epochs,
                                    customized_optimizer_params)
    train_dt_list, train_tasks = self.get_datasets_and_tasks(train_bundle_list)
    valid_dt_list, valid_tasks = self.get_datasets_and_tasks(valid_bundle_list,
                                                             self.config.early_stopping_patience)
    ## main loop
    self.early_stopped_epoch = -1
    self.train_loss_early_stopped_epoch = -1
    for cur_ep in range(math.ceil(self.config.epoch_count)):
        self.current_train_epoch = cur_ep
        self.bert_classifier.epoch_index += 1  # to track the overall number inside the classifier
        ## train
        if switch_on_train_mode:
            self.bert_classifier.train()
        else:
            self.bert_classifier.eval()
        self.__train_one_epoch(train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                               report_number_of_intervals, train_shuffle, train_drop_last,
                               balance_batch_mode_list)
        ## validation
        with torch.no_grad():
            stopping_valid_task = self.__validate_one_epoch(valid_bundle_list, valid_dt_list,
                                                            valid_tasks, weighted_instance_loss)
        ## post process epoch
        self.__print_epoch_results(cur_ep + 1, self.config.epoch_count, train_tasks, valid_tasks)
        if self.config.take_train_checkpoints and (cur_ep + 1) % self.config.train_checkpoint_interval == 0:
            print('saving checkpoint...')
            self.save(str(cur_ep + 1))
        if stopping_valid_task is not None:
            print('stopped early by \'' + stopping_valid_task[0] + '\'. restored the model of epoch {}'.
                  format(stopping_valid_task[1].learning_state.best_index + 1))
            self.early_stopped_epoch = stopping_valid_task[1].learning_state.best_index + 1
            break
        if minimum_train_loss is not None:
            for cur_task in train_tasks.items():
                if cur_task[1].size > 0 and cur_task[1].loss <= minimum_train_loss:
                    self.train_loss_early_stopped_epoch = cur_ep
                    break
            if self.train_loss_early_stopped_epoch >= 0:
                break
        gc.collect()
        ELib.PASS()
    ## save it if needed
    if self.config.take_train_checkpoints and (cur_ep + 1) % self.config.train_checkpoint_interval != 0:
        print('saving checkpoint...')
        self.save(str(cur_ep + 1))
    self.bert_classifier.cpu()
    ELib.PASS()
    return train_tasks, valid_tasks
def __get_query_vec(self, tokens, queries):
    result = [0] * len(tokens)
    for cur_q in queries:
        q_tokens = self.tokenizer.tokenize(cur_q)
        q_ind = self.__find_sublist(q_tokens, tokens)
        if q_ind >= 0:
            result[q_ind] = 1
        ELib.PASS()
    ELib.PASS()
    return result
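# Illustrative sketch (not part of the repo): __get_query_vec marks the position where a
# query's token sequence starts inside the tweet tokens. find_sublist below is a hypothetical
# stand-in for the repo's __find_sublist; the tokens are made up.
def find_sublist(needle, haystack):
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1

tokens = ['fever', 'after', 'flu', 'shot', 'today']
query_tokens = ['flu', 'shot']
query_vec = [0] * len(tokens)
start = find_sublist(query_tokens, tokens)
if start >= 0:
    query_vec[start] = 1
# query_vec == [0, 0, 1, 0, 0]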
def run(cmd, per_query, train_path, valid_path_nullable, test_path_nullable, unlabeled_path_nullable,
        model_path, model_path_2, lm_model_path, t_lbl_path_1, t_lbl_path_2, output_dir,
        device, device_2, seed, train_sample, unlabeled_sample):
    cmd = EPretrainCMD.get(cmd)
    lc = ELblConf(0, 1,
                  [ELbl(0, EVar.LblNonEventHealth),
                   ELbl(1, EVar.LblEventHealth)])  # map the dataset labels to negative and positive
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    queries = [None]
    if per_query:
        queries = ETweet.get_queries(ETweet.load(train_path, ELoadType.none))
    for q_ind, cur_query in enumerate(queries):
        if cur_query is not None:
            print('>>>>>>>> "' + cur_query + '" began')
        if cmd == EPretrainCMD.bert_reg:
            config = EBertConfig.get_config(cmd, EBertCLSType.none, model_path, model_path_2,
                                            lm_model_path, t_lbl_path_1, t_lbl_path_2, output_dir,
                                            5, device, device_2, seed, cur_query,
                                            gradient_checkpointing=False, check_early_stopping=False)
            train_bundle, valid_bundle, test_bundle, unlabeled_bundle = EInputBundle.get_data(
                config.label_count, lc, train_path, valid_path_nullable, test_path_nullable,
                unlabeled_path_nullable, cur_query)
            train_bundle = EPretrainProj.__stratified_sample_from_bundle(
                train_bundle, lc, config.seed, train_sample)
            EPretrainProj.__sample_from_unlabeled(unlabeled_bundle, unlabeled_sample, config.seed)
            EPretrainProj.__run_twin_net(config, lc, cur_query, train_bundle, valid_bundle,
                                         test_bundle, unlabeled_bundle)
            ELib.PASS()
    ELib.PASS()
def close_logs(self):
    for cur_name, cur_log in self.logs.items():
        cur_log.close()
    for cur_h in self.hooksForward:
        cur_h.remove()
    for cur_h in self.hooksBackward:
        cur_h.remove()
    ELib.PASS()
def __print_inst(self, tokens, tokens_len, tokens_ids, query_vec):
    tokens = ['['] + tokens + [']'] + ['?'] * (self.max_seq - tokens_len)
    for ind, _ in enumerate(tokens):
        print(str(ind) + ': ' + tokens[ind].ljust(7) +
              ' [' + str(tokens_ids[ind]).ljust(5) + ']' +
              '(' + str(query_vec[ind]) + ')')
    print()
    ELib.PASS()
def __init__(self, config):
    super(EBertClassifierSimple, self).__init__()
    self.config = config
    self.bert_layer = BertModel(self.config.bert_config)
    self.last_dropout_layer = nn.Dropout(self.config.dropout_prob)
    self.last_layer = nn.Linear(self.config.bert_config.hidden_size, self.config.label_count)
    self.apply(self._init_weights)
    ELib.PASS()
def populate_bundle(bundle, lc, size, seed):
    size_needed = size - len(bundle.tws)
    ratio_needed = float(size_needed) / len(bundle.tws)
    tws = ETweet.random_stratified_sample(bundle.tws, lc, ratio_needed, seed, True)
    if len(tws) + len(bundle.tws) > size:
        tws = tws[:len(tws) - 1]
    EInputBundle.append(bundle, bundle, tws)
    ELib.PASS()
def __hook_backward(self, module, grad_input, grad_output):
    if self.hook_activated and self.train_step % self.hook_interval == 0:
        name = self.__get_module_name(module)
        self.logs[name].add_histogram('out-gradients',
                                      grad_output[0].to('cpu').detach().numpy(),
                                      self.train_step)
        self.logs[name].add_histogram('bias-gradients',
                                      grad_input[0].to('cpu').detach().numpy(),
                                      self.train_step)
    ELib.PASS()
def __init__(self, task_list, input_x, input_y, input_y_row, queries, input_weight, input_meta, tws):
    self.task_list = task_list
    self.input_x = input_x
    self.input_y = input_y
    self.input_y_row = input_y_row
    self.queries = queries
    self.input_weight = input_weight
    self.input_meta = input_meta
    self.tws = tws
    ELib.PASS()
def __init__(self):
    super(EClassifier, self).__init__()
    self.logs = dict()
    self.train_state = 0
    self.epoch_index = -1
    self.train_step = -1
    self.hooked_modules = dict()
    self.hook_interval = None
    self.hook_activated = None
    self.hooksForward = list()
    self.hooksBackward = list()
    ELib.PASS()
def save(self, prefix=''):
    torch.save(self.bert_classifier.state_dict(),
               os.path.join(self.config.output_dir, prefix + 'pytorch_model.bin'))
    if os.path.join(self.config.model_path, 'config.json') != \
            os.path.join(self.config.output_dir, prefix + 'config.json'):
        shutil.copyfile(os.path.join(self.config.model_path, 'config.json'),
                        os.path.join(self.config.output_dir, prefix + 'config.json'))
    if os.path.join(self.config.model_path, 'vocab.txt') != \
            os.path.join(self.config.output_dir, prefix + 'vocab.txt'):
        shutil.copyfile(os.path.join(self.config.model_path, 'vocab.txt'),
                        os.path.join(self.config.output_dir, prefix + 'vocab.txt'))
    ELib.PASS()
def load(filePath, load_type, tweet_file=True):
    try:
        if type(load_type) == bool:
            print('change param to LoadType')
            exit(1)
        result = []
        file = open(filePath, "r", encoding="utf-8")
        lines = file.readlines()
        file.close()
        if load_type == ELoadType.pos_tagger:
            pass
        else:
            if tweet_file:
                for ind, line in enumerate(lines):
                    tw = ETweet(line)
                    result.append(tw)
                    if (ind + 1) % 1000000 == 0:
                        print((ind + 1), ' reading lines')
                if (ind + 1) > 1000000 and ELoadType.none != load_type:
                    print('reading tokens')
                tags = None
                if load_type == ELoadType.stored_tags or \
                        load_type == ELoadType.stored_tags_all_DontUseThis:
                    tags = EFeat1Gram.read_dep_tags(filePath + "-tags")
                elif load_type == ELoadType.stored_injected_tags or \
                        load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                    tags = EFeat1Gram.read_dep_tags(filePath + "-tags-synthesized")
                for ind, tw in enumerate(result):
                    if load_type == ELoadType.stored_tags_all_DontUseThis or \
                            load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                        tw.ETokens = EFeat1Gram.convert_all_tags_to_tokens(tags[ind])
                    elif load_type == ELoadType.stored_tags or \
                            load_type == ELoadType.stored_injected_tags:
                        tw.ETokens = EFeat1Gram.convert_tags_to_tokens(tags[ind])
                    if (ind + 1) % 500000 == 0 and ELoadType.none != load_type:
                        print((ind + 1), ' constructing tokens')
            else:
                for ind, line in enumerate(lines):
                    tokens = line.strip().split('\t')
                    tw = ETweet.__text_to_tweet_object(tokens)
                    result.append(tw)
        ELib.PASS()
    except Exception as err:
        print(colored('Error in loading: "{}"\n\n{}'.format(filePath, str(err)), 'red'))
        sys.exit(1)
    return result
def __load_pretrained_bert_module(module, config, loaded):
    if type(module) is BertModel:
        if module not in loaded:
            module.load_state_dict(
                EBertClassifier.__load_pretrained_bert_layer(config).state_dict())
            loaded.add(module)
        else:
            print('already loaded')
        ELib.PASS()
    elif isinstance(module, EBertModelWrapper):
        if module not in loaded:
            module.bert_layer.load_state_dict(
                EBertClassifier.__load_pretrained_bert_layer(config).state_dict())
            loaded.add(module)
        else:
            print('already loaded')
        ELib.PASS()
    else:
        print(colored('unknown bert module to load', 'red'))
def __hook_forward(self, module, input, output):
    if self.hook_activated and self.train_step % self.hook_interval == 0:
        name = self.__get_module_name(module)
        self.logs[name].add_histogram('activations',
                                      output.to('cpu').detach().numpy(),
                                      self.train_step)
        self.logs[name].add_histogram('weights',
                                      module.weight.data.to('cpu').detach().numpy(),
                                      self.train_step)
        self.logs[name].add_histogram('bias',
                                      module.bias.data.to('cpu').detach().numpy(),
                                      self.train_step)
    ELib.PASS()
def __sample_from_unlabeled(unlabeled_bundle, count, seed):
    random.seed(seed)
    count = min(count, len(unlabeled_bundle.tws))
    sample_set = set()
    while True:
        tw_ind = random.randint(0, len(unlabeled_bundle.tws) - 1)
        if unlabeled_bundle.tws[tw_ind] not in sample_set:
            sample_set.add(unlabeled_bundle.tws[tw_ind])
        if len(sample_set) >= count:
            break
    drop_list = [tw for tw in unlabeled_bundle.tws if tw not in sample_set]
    EInputBundle.remove(unlabeled_bundle, drop_list)
    ELib.PASS()
def get_data(label_count, lc, train_path_nullable, valid_path_nullable, test_path_nullable,
             unlabeled_path_nullable, filter_query=None, remove_unlabeled_test_tweets=False,
             tokenize_by_etokens=False, pivot_query=None, max_set_length=0):
    train_bundle, valid_bundle, test_bundle, unlabeled_bundle = [None] * 4
    if train_path_nullable is not None:
        tws_train = ETweet.load(train_path_nullable, ELoadType.none, tweet_file=False)
        train_bundle = EInputBundle.get_input_bundle([EVar.DefaultTask], tws_train, lc, filter_query,
                                                     tokenize_by_etokens, pivot_query, label_count,
                                                     max_set_length)
    if valid_path_nullable is not None:
        tws_valid = ETweet.load(valid_path_nullable, ELoadType.none, tweet_file=False)
        valid_bundle = EInputBundle.get_input_bundle([EVar.DefaultTask], tws_valid, lc, filter_query,
                                                     tokenize_by_etokens, pivot_query, label_count,
                                                     max_set_length)
    if test_path_nullable is not None:
        tws_test = ETweet.load(test_path_nullable, ELoadType.none, tweet_file=False)
        if remove_unlabeled_test_tweets:
            print('removing unlabeled test tweets ...')
            ind = 0
            while ind < len(tws_test):
                if tws_test[ind].Label == 0:
                    del tws_test[ind]
                    ELib.PASS()
                else:
                    ind += 1
        test_bundle = EInputBundle.get_input_bundle([EVar.DefaultTask], tws_test, lc, filter_query,
                                                    tokenize_by_etokens, pivot_query, label_count,
                                                    max_set_length)
    if unlabeled_path_nullable is not None:
        tws_unlabeled = ETweet.load(unlabeled_path_nullable, ELoadType.none, tweet_file=False)
        unlabeled_bundle = EInputBundle.get_input_bundle([EVar.DefaultTask], tws_unlabeled, lc,
                                                         filter_query, tokenize_by_etokens, pivot_query,
                                                         label_count, max_set_length)
    return train_bundle, valid_bundle, test_bundle, unlabeled_bundle
def __init__(self, seed, model_count, synchronized_bundle_indices=None):
    self.seed = seed
    self.model_count = model_count
    self.lock_dataset = threading.Lock()
    self.lock_batch = threading.Lock()
    self.lock_loss_calculation = threading.Lock()
    self.sync_bundle_indices = synchronized_bundle_indices
    self.sync_bundle_batches = dict()
    self.sync_bundle_batches_sizes = list()
    self.sync_counter = model_count  # this is needed for the first iteration
    self.sync_list = list()
    self.meta_list = list()
    for ind in range(model_count):
        self.sync_list.append(queue.Queue())
        self.meta_list.append(None)
    ELib.PASS()
def __align_tokens_in_tweet(pivot, pivot_pos, bert_tokens_rec, bert_tokens_rec_ind):
    span = 1
    while True:
        phrase = ''.join([entry[0] for entry in
                          bert_tokens_rec[bert_tokens_rec_ind:bert_tokens_rec_ind + span]])
        if pivot == phrase:
            return span, phrase
        if pivot_pos == 'U' and phrase == 'www':
            return span, 'www'
        if phrase == '[UNK]':
            return 1, pivot
        span += 1
        if len(bert_tokens_rec) < bert_tokens_rec_ind + span:
            return None
    ELib.PASS()
def save_tweets_as_text_file(start_id, tws, file_path):
    result = ''
    for cur_tw in tws:
        if cur_tw.Query != ETweet.tokenDummyQuery:
            query = cur_tw.Query
        else:
            query = ''
        if len(cur_tw.QueryList) > 0:
            for q in cur_tw.QueryList:
                query += '|' + q
        else:
            query += '|'
        result += '{}\t{}\t{}\t{}\n'.format(str(start_id).zfill(7), str(cur_tw.Label),
                                            query, cur_tw.Text.strip())
        start_id += 1
    with io.open(file_path, 'w', encoding='utf-8') as ptr:
        ptr.write(result)
    ELib.PASS()
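# Illustrative sketch (not part of the repo) of one output line produced by the format string
# above: zero-padded id, label, query (plus '|'-separated extra queries), and the raw text.
# The values are made up.
line = '{}\t{}\t{}\t{}\n'.format(str(12).zfill(7), str(1), 'flu shot|flu', 'got my flu shot today')
# line == '0000012\t1\tflu shot|flu\tgot my flu shot today\n'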
def remove(bundle, to_remove_tws):
    id_dict = dict()
    for cur_ind, cur_tw in enumerate(bundle.tws):
        id_dict[cur_tw.Tweetid] = cur_ind
    to_delete = list()
    for cur_tw in to_remove_tws:
        cur_ind = id_dict[cur_tw.Tweetid]
        to_delete.append(cur_ind)
    to_delete.sort(reverse=True)
    for cur_ind in to_delete:
        del bundle.input_x[cur_ind]
        for y_ind in range(len(bundle.input_y)):
            del bundle.input_y[y_ind][cur_ind]
            del bundle.input_y_row[y_ind][cur_ind]
        del bundle.queries[cur_ind]
        del bundle.input_weight[cur_ind]
        del bundle.input_meta[cur_ind]
        del bundle.tws[cur_ind]
    ELib.PASS()
def remove_module_from_optimizer(self, module):
    if isinstance(module, nn.ModuleList):
        print(colored('>>> cannot handle ModuleList to delete from the optimizer! <<<', 'red'))
        sys.exit(1)
    params = list(module.parameters())
    found = None
    for cur_group in self.optimizer.param_groups:
        try:
            if params[0] in cur_group['params']:
                found = cur_group
                break
        except:
            pass
    if found is not None:
        self.removed_modules.append([module, found['lr'], found['weight_decay'], found['eps']])
        self.optimizer.param_groups.remove(found)
    else:
        print(colored('>>> module was not found for deletion in the optimizer! <<< \n'
                      'perhaps you have passed "customized_params" to setup_optimizer()', 'red'))
        sys.exit(1)
    ELib.PASS()
def load_pretrained_bert_modules(modules, config):
    mod_list = modules
    if type(modules) is not OrderedDict:
        mod_list = OrderedDict([('reconfig', modules)])
    loaded = set()
    for cur_module in mod_list.items():
        if type(cur_module[1]) is BertModel or isinstance(cur_module[1], EBertModelWrapper):
            print('{}: '.format(cur_module[0]), end='', flush=True)
            EBertClassifier.__load_pretrained_bert_module(cur_module[1], config, loaded)
        elif type(cur_module[1]) is nn.ModuleList:
            for c_ind, cur_child_module in enumerate(cur_module[1]):
                if type(cur_child_module) is BertModel or isinstance(cur_child_module, EBertModelWrapper):
                    print('{}[{}]: '.format(cur_module[0], c_ind), end='', flush=True)
                    EBertClassifier.__load_pretrained_bert_module(cur_child_module, config, loaded)
    ELib.PASS()
def __train_one_epoch(self, train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                      report_number_of_intervals, train_shuffle, train_drop_last,
                      balance_batch_mode_list):
    batches = self.generate_batches(train_dt_list, self.config, train_shuffle, train_drop_last,
                                    self.current_train_epoch, input_mode, balance_batch_mode_list)
    [cur_task[1].reset() for cur_task in train_tasks.items()]
    for ba_ind, cur_batch in enumerate(batches):
        self.bert_classifier.train_step += 1  # to track the overall number inside the classifier
        while True:
            outcome = self.bert_classifier(cur_batch, False)
            self.__process_loss(outcome, cur_batch, train_tasks, True, weighted_instance_loss)
            if not self.delay_optimizer:
                break
        if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
            print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
        self.delete_batch_from_gpu(cur_batch, input_mode)
        del cur_batch, outcome
        ## in case there are multiple models and their losses are heavy (in terms of memory),
        ## you can call 'self.sync_obj.lock_loss_calculation.acquire()' in 'self.custom_train_loss_func()'.
        ## This way the losses are calculated one by one and after that the models are re-synched.
        if self.sync_obj is not None and self.sync_obj.lock_loss_calculation.locked():
            ## wait for the other models to arrive
            if self.sync_obj.sync_counter == self.sync_obj.model_count:
                self.sync_obj.reset()
            self.sync_obj.sync_counter += 1
            self.sync_obj.lock_loss_calculation.release()
            while self.sync_obj.sync_counter < self.sync_obj.model_count:
                self.sleep()
        # pprint(vars(self))
        # ELib.PASS()
    ## if there are multiple models avoid double printing the newline
    if self.sync_obj is None:
        print()
    elif self.model_id == 0:
        print()
    ## calculate the metric averages in the epoch
    for cur_task in train_tasks.items():
        if cur_task[1].size > 0:
            cur_task[1].loss /= cur_task[1].size
            cur_task[1].f1 = ELib.calculate_f1(cur_task[1].lbl_true, cur_task[1].lbl_pred)
    ELib.PASS()
def __init__(self, cmd, cls_type, bert_config, label_count, model_path, model_path_2, lm_model_path,
             t_lbl_path_1, t_lbl_path_2, output_dir, device, device_2, dropout_prob, max_seq,
             batch_size, epoch_count, seed, learn_rate, early_stopping_patience, max_grad_norm,
             weight_decay, adam_epsilon, warmup_steps, train_by_log_softmax,
             training_log_softmax_weight, training_softmax_temperature, balance_batch_mode,
             take_train_checkpoints, train_checkpoint_interval, check_early_stopping):
    self.cmd = cmd
    self.cls_type = cls_type
    self.bert_config = bert_config
    self.label_count = label_count
    self.model_path = model_path
    self.model_path_2 = model_path_2
    self.lm_model_path = lm_model_path
    self.t_lbl_path_1 = t_lbl_path_1
    self.t_lbl_path_2 = t_lbl_path_2
    self.output_dir = output_dir
    self.device = device
    self.device_2 = device_2
    self.dropout_prob = dropout_prob
    self.max_seq = max_seq
    self.batch_size = batch_size
    self.epoch_count = epoch_count
    self.seed = seed
    self.learn_rate = learn_rate
    self.early_stopping_patience = early_stopping_patience
    self.max_grad_norm = max_grad_norm
    self.weight_decay = weight_decay
    self.adam_epsilon = adam_epsilon
    self.warmup_steps = warmup_steps
    self.train_by_log_softmax = train_by_log_softmax
    self.training_log_softmax_weight = training_log_softmax_weight
    self.training_softmax_temperature = training_softmax_temperature
    self.balance_batch_mode = balance_batch_mode
    self.take_train_checkpoints = take_train_checkpoints
    self.train_checkpoint_interval = train_checkpoint_interval
    self.check_early_stopping = check_early_stopping
    ELib.PASS()
def set_module_learning_rate(self, module, lr):
    if isinstance(module, nn.ModuleList):
        print(colored('>>> cannot handle ModuleList to set the LR in the optimizer! <<<', 'red'))
        sys.exit(1)
    params = list(module.parameters())
    found = None
    for p_ind, cur_group in enumerate(self.optimizer.param_groups):
        try:
            if params[0] in cur_group['params']:
                found = cur_group
                break
        except:
            pass
    if found is not None:
        found['lr'] = lr
        found['initial_lr'] = lr
        self.scheduler.base_lrs[p_ind] = lr
    else:
        print(colored('>>> module was not found to set the LR in the optimizer! <<< \n'
                      'perhaps you have passed "customized_params" to setup_optimizer()', 'red'))
        sys.exit(1)
    ELib.PASS()
def __init__(self, config, sync_obj=None, **kwargs):
    # general properties
    self.config = config
    self.model_id = 0
    self.current_train_epoch = -1
    self.scheduler_overall_steps = -1
    self.early_stopped_epoch = -1
    self.train_loss_early_stopped_epoch = -1
    self.sync_obj = sync_obj
    self.delay_optimizer = False
    self.delay_optimizer_loss = 0.0
    self.custom_train_loss_func = None
    self.custom_test_loss_func = None
    self.removed_modules = list()
    self.init_seed(self.config.seed)
    # cls settings
    if self.config.cls_type == EBertCLSType.simple:
        self.bert_classifier = EBertClassifier.create(EBertClassifierSimple, self, self.config, **kwargs)
    else:
        self.bert_classifier = None
    self.tokenizer = BertTokenizer.from_pretrained(self.config.model_path)
    ELib.PASS()
def forward(self, input_batch, apply_softmax):
    b_output = self.bert_layer(input_batch['x'],
                               attention_mask=input_batch['mask'],
                               token_type_ids=input_batch['type'])
    last_hidden_states = b_output[0]
    output_pooled = b_output[1]
    # keep detached copies of the pooled and per-token vectors for later inspection,
    # truncating each sequence's token vectors to its real length
    self.output_vecs = np.copy(output_pooled.detach().to('cpu').numpy()).tolist()
    self.output_vecs_detail = np.copy(last_hidden_states.detach().to('cpu').numpy()).tolist()
    for cur_seq_ind, cur_seq_len in enumerate(input_batch['len']):
        self.output_vecs_detail[cur_seq_ind] = self.output_vecs_detail[cur_seq_ind][:cur_seq_len]
    ELib.PASS()
    output_pooled = self.last_dropout_layer(output_pooled)
    logits = self.last_layer(output_pooled)
    if apply_softmax:
        logits = F.softmax(logits, dim=1)
    return [(EVar.DefaultTask, logits)]  # or [(input_batch['task_list'][0][0], logits)]
def add_module_to_optimizer(self, module):
    if isinstance(module, nn.ModuleList):
        print(colored('>>> cannot handle ModuleList to add to the optimizer! <<<', 'red'))
        sys.exit(1)
    lr = self.config.learn_rate
    weight_decay = self.config.weight_decay
    eps = self.config.adam_epsilon
    for cur_module_info in self.removed_modules:
        if cur_module_info[0] == module:
            # restore the hyperparameters the module had before it was removed
            lr = cur_module_info[1]
            weight_decay = cur_module_info[2]
            eps = cur_module_info[3]
            self.removed_modules.remove(cur_module_info)
            break
    self.optimizer.add_param_group({'params': module.parameters(),
                                    'lr': lr,
                                    'weight_decay': weight_decay,
                                    'eps': eps})
    ELib.PASS()
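# Illustrative sketch (not part of the repo) of the underlying PyTorch mechanism used by
# add_module_to_optimizer / remove_module_from_optimizer: param groups can be appended to a
# live optimizer, each with its own hyperparameters. The modules and values are made up.
import torch
import torch.nn as nn

backbone = nn.Linear(8, 4)
extra_head = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(backbone.parameters(), lr=5e-5, weight_decay=0.01, eps=1e-8)
optimizer.add_param_group({'params': extra_head.parameters(),
                           'lr': 1e-4, 'weight_decay': 0.0, 'eps': 1e-8})
# optimizer.param_groups now holds two groups with independent learning rates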
def setup_logs(self, dir_path, curve_names, add_hooks=False, hook_interval=10):
    # one summary_writer will be created for each name in curve_names;
    # if add_hooks=True, one summary_writer will also be created
    # for each module in self.hooked_modules
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    for cur_name in curve_names:
        self.logs[cur_name] = SummaryWriter(os.path.join(dir_path, cur_name))
    if add_hooks:
        self.hook_interval = hook_interval
        self.hook_activated = True
        for name, module in self.hooked_modules.items():
            self.logs[name] = SummaryWriter(os.path.join(dir_path, name))
            self.hooksForward.append(module.register_forward_hook(self.__hook_forward))
            self.hooksBackward.append(module.register_backward_hook(self.__hook_backward))
    ELib.PASS()
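# Illustrative, standalone sketch (not part of the repo) of the logging pattern that setup_logs()
# wires up: one SummaryWriter per module plus a forward hook that dumps activation histograms.
# The module, log path, and step value are assumptions.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

layer = nn.Linear(8, 2)
writer = SummaryWriter('runs/demo/last_layer')

def log_activations(module, inputs, output):
    writer.add_histogram('activations', output.detach().cpu().numpy(), 0)

handle = layer.register_forward_hook(log_activations)
layer(torch.randn(3, 8))   # triggers the hook and writes one histogram
handle.remove()
writer.close()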