def __update_label_info(unlabeled_bundle, lbls_list, soft_lbl_list,
                         do_softmax, temperature):
     if len(lbls_list) == 1:
         for tw_ind, _ in enumerate(unlabeled_bundle.tws):
             unlabeled_bundle.input_y[0][tw_ind] = lbls_list[0][tw_ind]
             if do_softmax:
                 unlabeled_bundle.input_y_row[0][tw_ind] = \
                     F.softmax(torch.tensor(soft_lbl_list[0][tw_ind]) / temperature).numpy().tolist()
             else:
                 unlabeled_bundle.input_y_row[0][tw_ind] = soft_lbl_list[0][
                     tw_ind]
     elif len(lbls_list) == 2:
         for tw_ind, _ in enumerate(unlabeled_bundle.tws):
             if do_softmax:
                 y_row = F.softmax(torch.tensor(soft_lbl_list[0][tw_ind]) / temperature).numpy() + \
                         F.softmax(torch.tensor(soft_lbl_list[1][tw_ind]) / temperature).numpy()
             else:
                 y_row = np.array(soft_lbl_list[0][tw_ind]) + np.array(
                     soft_lbl_list[1][tw_ind])
             y_row = (y_row / 2).tolist()
             unlabeled_bundle.input_y_row[0][tw_ind] = y_row
             unlabeled_bundle.input_y[0][tw_ind] = y_row.index(max(y_row))
         ELib.PASS()
     else:
         raise Exception('not implemented function!')
     ELib.PASS()
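The core of __update_label_info in isolation: a minimal sketch of temperature-scaled softmax averaging for two sets of teacher scores. The helper name average_soft_labels and the input values are illustrative, not part of the repository.

import torch
import torch.nn.functional as F

def average_soft_labels(logits_a, logits_b, temperature):
    # Illustrative helper, not part of the repository.
    # Temperature-scaled softmax of each teacher's scores, then an element-wise average.
    p_a = F.softmax(torch.tensor(logits_a) / temperature, dim=0)
    p_b = F.softmax(torch.tensor(logits_b) / temperature, dim=0)
    y_row = ((p_a + p_b) / 2).tolist()
    # The hard label is the index of the largest averaged probability.
    return y_row, y_row.index(max(y_row))

y_row, y = average_soft_labels([2.0, 0.5], [1.0, 1.5], temperature=2.0)
print(y_row, y)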
Example #2
 def train(self, train_bundle_list, valid_bundle_list = None, weighted_instance_loss = False,
           input_mode = EInputListMode.sequential, setup_learning_tools=True,
           extra_scheduled_trainset_size=0, extra_scheduled_epochs=0, customized_optimizer_params=None,
           report_number_of_intervals=20, switch_on_train_mode=True, train_shuffle=True, train_drop_last=True,
           balance_batch_mode_list=None, minimum_train_loss=None):
     if len(train_bundle_list) == 1 and len(train_bundle_list[0].tws) < self.config.batch_size:
         return
     ## init
     self.bert_classifier.to(self.config.device)
     self.bert_classifier.zero_grad()
     if setup_learning_tools:
         ## caveat: if you have called train() before this will reset the learning rate and the scheduler!
         self.__setup_learning_tools(train_bundle_list, weighted_instance_loss, input_mode,
                                     extra_scheduled_trainset_size, extra_scheduled_epochs,
                                     customized_optimizer_params)
     train_dt_list, train_tasks = self.get_datasets_and_tasks(train_bundle_list)
     valid_dt_list, valid_tasks = self.get_datasets_and_tasks(valid_bundle_list, self.config.early_stopping_patience)
     ## main loop
     self.early_stopped_epoch = -1
     self.train_loss_early_stopped_epoch = -1
     for cur_ep in range(math.ceil(self.config.epoch_count)):
         self.current_train_epoch = cur_ep
         self.bert_classifier.epoch_index += 1 # to track the overall number inside the classifier
         ## train
         if switch_on_train_mode:
             self.bert_classifier.train()
         else:
             self.bert_classifier.eval()
         self.__train_one_epoch(train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                                report_number_of_intervals, train_shuffle, train_drop_last, balance_batch_mode_list)
         ## validation
         with torch.no_grad():
             stopping_valid_task = self.__validate_one_epoch(valid_bundle_list, valid_dt_list,
                                                             valid_tasks, weighted_instance_loss)
         ## post process epoch
         self.__print_epoch_results(cur_ep + 1, self.config.epoch_count, train_tasks, valid_tasks)
         if self.config.take_train_checkpoints and (cur_ep + 1) % self.config.train_checkpoint_interval == 0:
             print('saving checkpoint...')
             self.save(str(cur_ep + 1))
         if stopping_valid_task is not None:
             print('stopped early by \''+ stopping_valid_task[0] + '\'. restored the model of epoch {}'.
                   format(stopping_valid_task[1].learning_state.best_index + 1))
             self.early_stopped_epoch = stopping_valid_task[1].learning_state.best_index + 1
             break
         if minimum_train_loss is not None:
             for cur_task in train_tasks.items():
                 if cur_task[1].size > 0 and cur_task[1].loss <= minimum_train_loss:
                     self.train_loss_early_stopped_epoch = cur_ep
                     break
         if self.train_loss_early_stopped_epoch >= 0:
             break
         gc.collect()
         ELib.PASS()
     ## save it if needed
     if self.config.take_train_checkpoints and (cur_ep + 1) % self.config.train_checkpoint_interval != 0:
         print('saving checkpoint...')
         self.save(str(cur_ep + 1))
     self.bert_classifier.cpu()
     ELib.PASS()
     return train_tasks, valid_tasks
 def __get_query_vec(self, tokens, queries):
     result = [0] * len(tokens)
     for cur_q in queries:
         q_tokens = self.tokenizer.tokenize(cur_q)
         q_ind = self.__find_sublist(q_tokens, tokens)
         if q_ind >= 0:
             result[q_ind] = 1
         ELib.PASS()
     ELib.PASS()
     return result
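A rough, standalone illustration of the query-vector idea in __get_query_vec: mark the first token of each query phrase found in the token sequence. find_sublist and the toy tokens below stand in for the class's own helpers.

def find_sublist(needle, haystack):
    # Toy stand-in for the class's own __find_sublist helper.
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1

tokens = ['flu', 'shot', 'near', 'me']          # made-up tokenized tweet
queries = [['flu', 'shot'], ['fever']]          # made-up tokenized queries
result = [0] * len(tokens)
for q_tokens in queries:
    q_ind = find_sublist(q_tokens, tokens)
    if q_ind >= 0:
        result[q_ind] = 1                       # flag the first token of the match
print(result)  # [1, 0, 0, 0]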
 def run(cmd, per_query, train_path, valid_path_nullable,
         test_path_nullable, unlabeled_path_nullable, model_path,
         model_path_2, lm_model_path, t_lbl_path_1, t_lbl_path_2,
         output_dir, device, device_2, seed, train_sample,
         unlabeled_sample):
     cmd = EPretrainCMD.get(cmd)
     lc = ELblConf(
         0, 1,
         [ELbl(0, EVar.LblNonEventHealth),
          ELbl(1, EVar.LblEventHealth)
           ])  # mapping the dataset labels to negative and positive
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
     queries = [None]
     if per_query:
         queries = ETweet.get_queries(
             ETweet.load(train_path, ELoadType.none))
     for q_ind, cur_query in enumerate(queries):
         if cur_query is not None:
             print('>>>>>>>> "' + cur_query + '" began')
         if cmd == EPretrainCMD.bert_reg:
             config = EBertConfig.get_config(cmd,
                                             EBertCLSType.none,
                                             model_path,
                                             model_path_2,
                                             lm_model_path,
                                             t_lbl_path_1,
                                             t_lbl_path_2,
                                             output_dir,
                                             5,
                                             device,
                                             device_2,
                                             seed,
                                             cur_query,
                                             gradient_checkpointing=False,
                                             check_early_stopping=False)
             train_bundle, valid_bundle, test_bundle, unlabeled_bundle = EInputBundle.get_data(
                 config.label_count, lc, train_path, valid_path_nullable,
                 test_path_nullable, unlabeled_path_nullable, cur_query)
             train_bundle = EPretrainProj.__stratified_sample_from_bundle(
                 train_bundle, lc, config.seed, train_sample)
             EPretrainProj.__sample_from_unlabeled(unlabeled_bundle,
                                                   unlabeled_sample,
                                                   config.seed)
             EPretrainProj.__run_twin_net(config, lc, cur_query,
                                          train_bundle, valid_bundle,
                                          test_bundle, unlabeled_bundle)
             ELib.PASS()
     ELib.PASS()
Example #5
 def close_logs(self):
     for cur_name, cur_log in self.logs.items():
         cur_log.close()
     for cur_h in self.hooksForward:
         cur_h.remove()
     for cur_h in self.hooksBackward:
         cur_h.remove()
     ELib.PASS()
 def __print_inst(self, tokens, tokens_len, tokens_ids, query_vec):
     tokens = ['['] + tokens + [']'] + ['?'] * (self.max_seq - tokens_len)
     for ind, _ in enumerate(tokens):
         print(
             str(ind) + ': ' + tokens[ind].ljust(7) + ' [' +
             str(tokens_ids[ind]).ljust(5) + ']' + '(' +
             str(query_vec[ind]) + ')')
     print()
     ELib.PASS()
Example #7
 def __init__(self, config):
     super(EBertClassifierSimple, self).__init__()
     self.config = config
     self.bert_layer = BertModel(self.config.bert_config)
     self.last_dropout_layer = nn.Dropout(self.config.dropout_prob)
     self.last_layer = nn.Linear(self.config.bert_config.hidden_size,
                                 self.config.label_count)
     self.apply(self._init_weights)
     ELib.PASS()
 def populate_bundle(bundle, lc, size, seed):
     size_needed = size - len(bundle.tws)
     ratio_needed = float(size_needed) / len(bundle.tws)
     tws = ETweet.random_stratified_sample(bundle.tws, lc, ratio_needed,
                                           seed, True)
     if len(tws) + len(bundle.tws) > size:
         tws = tws[:len(tws) - 1]
     EInputBundle.append(bundle, bundle, tws)
     ELib.PASS()
Example #9
 def __hook_backward(self, module, grad_input, grad_output):
     if self.hook_activated and self.train_step % self.hook_interval == 0:
         name = self.__get_module_name(module)
         self.logs[name].add_histogram(
             'out-gradients', grad_output[0].to('cpu').detach().numpy(),
             self.train_step)
         self.logs[name].add_histogram(
             'bias-gradients', grad_input[0].to('cpu').detach().numpy(),
             self.train_step)
     ELib.PASS()
 def __init__(self, task_list, input_x, input_y, input_y_row, queries,
              input_weight, input_meta, tws):
     self.task_list = task_list
     self.input_x = input_x
     self.input_y = input_y
     self.input_y_row = input_y_row
     self.queries = queries
     self.input_weight = input_weight
     self.input_meta = input_meta
     self.tws = tws
     ELib.PASS()
Example #11
 def __init__(self):
     super(EClassifier, self).__init__()
     self.logs = dict()
     self.train_state = 0
     self.epoch_index = -1
     self.train_step = -1
     self.hooked_modules = dict()
     self.hook_interval = None
     self.hook_activated = None
     self.hooksForward = list()
     self.hooksBackward = list()
     ELib.PASS()
Example #12
 def save(self, prefix=''):
     torch.save(self.bert_classifier.state_dict(),
                os.path.join(self.config.output_dir, prefix + 'pytorch_model.bin'))
     if os.path.join(self.config.model_path, 'config.json') != \
             os.path.join(self.config.output_dir, prefix + 'config.json'):
         shutil.copyfile(os.path.join(self.config.model_path, 'config.json'),
                         os.path.join(self.config.output_dir, prefix + 'config.json'))
     if os.path.join(self.config.model_path, 'vocab.txt') != \
             os.path.join(self.config.output_dir, prefix + 'vocab.txt'):
         shutil.copyfile(os.path.join(self.config.model_path, 'vocab.txt'),
                         os.path.join(self.config.output_dir, prefix + 'vocab.txt'))
     ELib.PASS()
Example #13
 def load(filePath, load_type, tweet_file=True):
     try:
         if type(load_type) == bool:
             print('change param to LoadType')
             exit(1)
         result = []
         file = open(filePath, "r", encoding="utf-8")
         lines = file.readlines()
         file.close()
         if load_type == ELoadType.pos_tagger:
             pass
         else:
             if tweet_file:
                 for ind, line in enumerate(lines):
                     tw = ETweet(line)
                     result.append(tw)
                     if (ind + 1) % 1000000 == 0:
                         print((ind + 1), ' reading lines')
                 if (ind + 1) > 1000000 and ELoadType.none != load_type:
                     print('reading tokens')
                 tags = None
                 if load_type == ELoadType.stored_tags or \
                         load_type == ELoadType.stored_tags_all_DontUseThis:
                     tags = EFeat1Gram.read_dep_tags(filePath + "-tags")
                 elif load_type == ELoadType.stored_injected_tags or \
                         load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                     tags = EFeat1Gram.read_dep_tags(filePath +
                                                     "-tags-synthesized")
                 for ind, tw in enumerate(result):
                     if load_type == ELoadType.stored_tags_all_DontUseThis or \
                             load_type == ELoadType.stored_injected_tags_all_DontUseThis:
                         tw.ETokens = EFeat1Gram.convert_all_tags_to_tokens(
                             tags[ind])
                     elif load_type == ELoadType.stored_tags or \
                             load_type == ELoadType.stored_injected_tags:
                         tw.ETokens = EFeat1Gram.convert_tags_to_tokens(
                             tags[ind])
                     if (ind + 1
                         ) % 500000 == 0 and ELoadType.none != load_type:
                         print((ind + 1), ' constructing tokens')
             else:
                 for ind, line in enumerate(lines):
                     tokens = line.strip().split('\t')
                     tw = ETweet.__text_to_tweet_object(tokens)
                     result.append(tw)
                 ELib.PASS()
     except Exception as err:
         print(
             colored(
                 'Error in loading: "{}"\n\n{}'.format(filePath, str(err)),
                 'red'))
         sys.exit(1)
     return result
Example #14
 def __load_pretrained_bert_module(module, config, loaded):
     if type(module) is BertModel:
         if module not in loaded:
             module.load_state_dict(
                 EBertClassifier.__load_pretrained_bert_layer(
                     config).state_dict())
             loaded.add(module)
         else:
             print('already loaded')
         ELib.PASS()
     elif isinstance(module, EBertModelWrapper):
         if module not in loaded:
             module.bert_layer.load_state_dict(
                 EBertClassifier.__load_pretrained_bert_layer(
                     config).state_dict())
             loaded.add(module)
         else:
             print('already loaded')
         ELib.PASS()
     else:
         print(colored('unknown bert module to load', 'red'))
Example #15
 def __hook_forward(self, module, input, output):
     if self.hook_activated and self.train_step % self.hook_interval == 0:
         name = self.__get_module_name(module)
         self.logs[name].add_histogram('activations',
                                       output.to('cpu').detach().numpy(),
                                       self.train_step)
         self.logs[name].add_histogram(
             'weights',
             module.weight.data.to('cpu').detach().numpy(), self.train_step)
         self.logs[name].add_histogram(
             'bias',
             module.bias.data.to('cpu').detach().numpy(), self.train_step)
     ELib.PASS()
 def __sample_from_unlabeled(unlabeled_bundle, count, seed):
     random.seed(seed)
     count = min(count, len(unlabeled_bundle.tws))
     sample_set = set()
     while True:
         tw_ind = random.randint(0, len(unlabeled_bundle.tws) - 1)
         if unlabeled_bundle.tws[tw_ind] not in sample_set:
             sample_set.add(unlabeled_bundle.tws[tw_ind])
             if len(sample_set) >= count:
                 break
     drop_list = [tw for tw in unlabeled_bundle.tws if tw not in sample_set]
     EInputBundle.remove(unlabeled_bundle, drop_list)
     ELib.PASS()
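The same keep-and-drop sampling pattern as __sample_from_unlabeled, sketched on a plain list; unlike taking random.sample directly, it keeps the surviving items in their original order. The list and seed are placeholders.

import random

def keep_random_subset(items, count, seed):
    # Toy version of the keep-and-drop sampling above.
    random.seed(seed)
    count = min(count, len(items))
    kept = set()
    while len(kept) < count:
        kept.add(random.randrange(len(items)))   # draw indices until enough are unique
    # Keep the survivors in their original order; everything else would be dropped.
    return [item for ind, item in enumerate(items) if ind in kept]

print(keep_random_subset(list('abcdefg'), 3, seed=7))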
 def get_data(label_count,
              lc,
              train_path_nullable,
              valid_path_nullable,
              test_path_nullable,
              unlabeled_path_nullable,
              filter_query=None,
              remove_unlabeled_test_tweets=False,
              tokenize_by_etokens=False,
              pivot_query=None,
              max_set_length=0):
     train_bundle, valid_bundle, test_bundle, unlabeled_bundle = [None] * 4
     if train_path_nullable is not None:
         tws_train = ETweet.load(train_path_nullable,
                                 ELoadType.none,
                                 tweet_file=False)
         train_bundle = EInputBundle.get_input_bundle(
             [EVar.DefaultTask], tws_train, lc, filter_query,
             tokenize_by_etokens, pivot_query, label_count, max_set_length)
     if valid_path_nullable is not None:
         tws_valid = ETweet.load(valid_path_nullable,
                                 ELoadType.none,
                                 tweet_file=False)
         valid_bundle = EInputBundle.get_input_bundle(
             [EVar.DefaultTask], tws_valid, lc, filter_query,
             tokenize_by_etokens, pivot_query, label_count, max_set_length)
     if test_path_nullable is not None:
         tws_test = ETweet.load(test_path_nullable,
                                ELoadType.none,
                                tweet_file=False)
         if remove_unlabeled_test_tweets:
             print('removing unlabeled test tweets ...')
             ind = 0
             while ind < len(tws_test):
                 if tws_test[ind].Label == 0:
                     del tws_test[ind]
                     ELib.PASS()
                 else:
                     ind += 1
         test_bundle = EInputBundle.get_input_bundle(
             [EVar.DefaultTask], tws_test, lc, filter_query,
             tokenize_by_etokens, pivot_query, label_count, max_set_length)
     if unlabeled_path_nullable is not None:
         tws_unlabeled = ETweet.load(unlabeled_path_nullable,
                                     ELoadType.none,
                                     tweet_file=False)
         unlabeled_bundle = EInputBundle.get_input_bundle(
             [EVar.DefaultTask], tws_unlabeled, lc, filter_query,
             tokenize_by_etokens, pivot_query, label_count, max_set_length)
     return train_bundle, valid_bundle, test_bundle, unlabeled_bundle
 def __init__(self, seed, model_count, synchronized_bundle_indices=None):
     self.seed = seed
     self.model_count = model_count
     self.lock_dataset = threading.Lock()
     self.lock_batch = threading.Lock()
     self.lock_loss_calculation = threading.Lock()
     self.sync_bundle_indices = synchronized_bundle_indices
     self.sync_bundle_batches = dict()
     self.sync_bundle_batches_sizes = list()
     self.sync_counter = model_count  # this is needed for the first iteration
     self.sync_list = list()
     self.meta_list = list()
     for ind in range(model_count):
         self.sync_list.append(queue.Queue())
         self.meta_list.append(None)
     ELib.PASS()
 def __align_tokens_in_tweet(pivot, pivot_pos, bert_tokens_rec,
                             bert_tokens_rec_ind):
     span = 1
     while True:
         phrase = ''.join([
             entry[0] for entry in
             bert_tokens_rec[bert_tokens_rec_ind:bert_tokens_rec_ind + span]
         ])
         if pivot == phrase:
             return span, phrase
         if pivot_pos == 'U' and phrase == 'www':
             return span, 'www'
         if phrase == '[UNK]':
             return 1, pivot
         span += 1
         if len(bert_tokens_rec) < bert_tokens_rec_ind + span:
             return None
     ELib.PASS()
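A hedged walk-through of the span matching in __align_tokens_in_tweet: WordPiece sub-tokens are concatenated until they rebuild the original (pivot) token. The tokens are made up, and stripping the '##' prefixes is an assumption about how the records were prepared.

def align(pivot, bert_tokens, start):
    # Made-up tokens; stripping '##' is an assumption about how the records were prepared.
    span = 1
    while start + span <= len(bert_tokens):
        phrase = ''.join(t.replace('##', '') for t in bert_tokens[start:start + span])
        if phrase == pivot:
            return span, phrase                  # the pivot maps to `span` WordPiece tokens
        span += 1
    return None

print(align('unhappiness', ['the', 'un', '##happi', '##ness', 'grew'], 1))  # (3, 'unhappiness')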
Example #20
 def save_tweets_as_text_file(start_id, tws, file_path):
     result = ''
     for cur_tw in tws:
         if cur_tw.Query != ETweet.tokenDummyQuery:
             query = cur_tw.Query
         else:
             query = ''
             if len(cur_tw.QueryList) > 0:
                 for q in cur_tw.QueryList:
                     query += '|' + q
             else:
                 query += '|'
         result += '{}\t{}\t{}\t{}\n'.format(
             str(start_id).zfill(7), str(cur_tw.Label), query,
             cur_tw.Text.strip())
         start_id += 1
     with io.open(file_path, 'w', encoding='utf-8') as ptr:
         ptr.write(result)
     ELib.PASS()
 def remove(bundle, to_remove_tws):
     id_dict = dict()
     for cur_ind, cur_tw in enumerate(bundle.tws):
         id_dict[cur_tw.Tweetid] = cur_ind
     to_delete = list()
     for cur_tw in to_remove_tws:
         cur_ind = id_dict[cur_tw.Tweetid]
         to_delete.append(cur_ind)
     to_delete.sort(reverse=True)
     for cur_ind in to_delete:
         del bundle.input_x[cur_ind]
         for y_ind in range(len(bundle.input_y)):
             del bundle.input_y[y_ind][cur_ind]
             del bundle.input_y_row[y_ind][cur_ind]
         del bundle.queries[cur_ind]
         del bundle.input_weight[cur_ind]
         del bundle.input_meta[cur_ind]
         del bundle.tws[cur_ind]
     ELib.PASS()
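Why remove() sorts the indices in descending order before deleting, shown on a toy list: deleting from the back keeps the not-yet-deleted positions valid.

items = ['a', 'b', 'c', 'd', 'e']   # toy list
to_delete = [1, 3]                  # indices of 'b' and 'd'
for ind in sorted(to_delete, reverse=True):
    del items[ind]                  # deleting from the back keeps the remaining indices stable
print(items)                        # ['a', 'c', 'e']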
Example #22
 def remove_module_from_optimizer(self, module):
     if isinstance(module, nn.ModuleList):
         print(colored('>>> cannot handle ModuleList to delete from the optimizer! <<<', 'red'))
         sys.exit(1)
     params = list(module.parameters())
     found = None
     for cur_group in self.optimizer.param_groups:
         try:
             if params[0] in cur_group['params']:
                 found = cur_group
                 break
         except:
             pass
     if found is not None:
         self.removed_modules.append([module, found['lr'], found['weight_decay'], found['eps']])
         self.optimizer.param_groups.remove(found)
     else:
         print(colored('>>> module was not found for deletion in the optimizer! <<< \n'
                       'perhaps you have passed "customized_params" to setup_optimizer()', 'red'))
         sys.exit(1)
     ELib.PASS()
Example #23
 def load_pretrained_bert_modules(modules, config):
     mod_list = modules
     if type(modules) is not OrderedDict:
         mod_list = OrderedDict([('reconfig', modules)])
     loaded = set()
     for cur_module in mod_list.items():
         if type(cur_module[1]) is BertModel or isinstance(
                 cur_module[1], EBertModelWrapper):
             print('{}: '.format(cur_module[0]), end='', flush=True)
             EBertClassifier.__load_pretrained_bert_module(
                 cur_module[1], config, loaded)
         elif type(cur_module[1]) is nn.ModuleList:
             for c_ind, cur_child_module in enumerate(cur_module[1]):
                 if type(cur_child_module) is BertModel or isinstance(
                         cur_child_module, EBertModelWrapper):
                     print('{}[{}]: '.format(cur_module[0], c_ind),
                           end='',
                           flush=True)
                     EBertClassifier.__load_pretrained_bert_module(
                         cur_child_module, config, loaded)
     ELib.PASS()
Example #24
 def __train_one_epoch(self, train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                       report_number_of_intervals, train_shuffle, train_drop_last, balance_batch_mode_list):
     batches = self.generate_batches(train_dt_list, self.config, train_shuffle, train_drop_last,
                                     self.current_train_epoch, input_mode, balance_batch_mode_list)
     [cur_task[1].reset() for cur_task in train_tasks.items()]
     for ba_ind, cur_batch in enumerate(batches):
         self.bert_classifier.train_step += 1  # to track the overall number inside the classifier
         while True:
             outcome = self.bert_classifier(cur_batch, False)
             self.__process_loss(outcome, cur_batch, train_tasks, True, weighted_instance_loss)
             if not self.delay_optimizer:
                 break
         if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
             print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
         self.delete_batch_from_gpu(cur_batch, input_mode)
         del cur_batch, outcome
         ## in case there are multiple models and their losses are heavy (in terms of memory)
         ## you can call 'self.sync_obj.lock_loss_calculation.acquire()' in 'self.custom_train_loss_func()'
         ## This way the losses are calculated one by one and after that the models are re-synched
         if self.sync_obj is not None and self.sync_obj.lock_loss_calculation.locked():
             ## wait for the other models to arrive
             if self.sync_obj.sync_counter == self.sync_obj.model_count:
                 self.sync_obj.reset()
             self.sync_obj.sync_counter += 1
             self.sync_obj.lock_loss_calculation.release()
             while self.sync_obj.sync_counter < self.sync_obj.model_count:
                 self.sleep()
         # pprint(vars(self))
         # ELib.PASS()
     ## if there are multiple models avoid double printing the newline
     if self.sync_obj is None:
         print()
     elif self.model_id == 0:
         print()
     ## calculate the metric averages in the epoch
     for cur_task in train_tasks.items():
         if cur_task[1].size > 0:
             cur_task[1].loss /= cur_task[1].size
             cur_task[1].f1 = ELib.calculate_f1(cur_task[1].lbl_true, cur_task[1].lbl_pred)
     ELib.PASS()
 def __init__(self, cmd, cls_type, bert_config, label_count, model_path,
              model_path_2, lm_model_path, t_lbl_path_1, t_lbl_path_2,
              output_dir, device, device_2, dropout_prob, max_seq,
              batch_size, epoch_count, seed, learn_rate,
              early_stopping_patience, max_grad_norm, weight_decay,
              adam_epsilon, warmup_steps, train_by_log_softmax,
              training_log_softmax_weight, training_softmax_temperature,
              balance_batch_mode, take_train_checkpoints,
              train_checkpoint_interval, check_early_stopping):
     self.cmd = cmd
     self.cls_type = cls_type
     self.bert_config = bert_config
     self.label_count = label_count
     self.model_path = model_path
     self.model_path_2 = model_path_2
     self.lm_model_path = lm_model_path
     self.t_lbl_path_1 = t_lbl_path_1
     self.t_lbl_path_2 = t_lbl_path_2
     self.output_dir = output_dir
     self.device = device
     self.device_2 = device_2
     self.dropout_prob = dropout_prob
     self.max_seq = max_seq
     self.batch_size = batch_size
     self.epoch_count = epoch_count
     self.seed = seed
     self.learn_rate = learn_rate
     self.early_stopping_patience = early_stopping_patience
     self.max_grad_norm = max_grad_norm
     self.weight_decay = weight_decay
     self.adam_epsilon = adam_epsilon
     self.warmup_steps = warmup_steps
     self.train_by_log_softmax = train_by_log_softmax
     self.training_log_softmax_weight = training_log_softmax_weight
     self.training_softmax_temperature = training_softmax_temperature
     self.balance_batch_mode = balance_batch_mode
     self.take_train_checkpoints = take_train_checkpoints
     self.train_checkpoint_interval = train_checkpoint_interval
     self.check_early_stopping = check_early_stopping
     ELib.PASS()
Example #26
 def set_module_learning_rate(self, module, lr):
     if isinstance(module, nn.ModuleList):
         print(colored('>>> cannot handle ModuleList to set the LR in the optimizer! <<<', 'red'))
         sys.exit(1)
     params = list(module.parameters())
     found = None
     for p_ind, cur_group in enumerate(self.optimizer.param_groups):
         try:
             if params[0] in cur_group['params']:
                 found = cur_group
                 break
         except:
             pass
     if found is not None:
         found['lr'] = lr
         found['initial_lr'] = lr
         self.scheduler.base_lrs[p_ind] = lr
     else:
         print(colored('>>> module was not found to set the LR in the optimizer! <<< \n'
                       'perhaps you have passed "customized_params" to setup_optimizer()', 'red'))
         sys.exit(1)
     ELib.PASS()
Example #27
 def __init__(self, config, sync_obj=None, **kwargs):
     # general properties
     self.config = config
     self.model_id = 0
     self.current_train_epoch = -1
     self.scheduler_overall_steps = -1
     self.early_stopped_epoch = -1
     self.train_loss_early_stopped_epoch = -1
     self.sync_obj = sync_obj
     self.delay_optimizer = False
     self.delay_optimizer_loss = 0.0
     self.custom_train_loss_func = None
     self.custom_test_loss_func = None
     self.removed_modules = list()
     self.init_seed(self.config.seed)
     # cls settings
     if self.config.cls_type == EBertCLSType.simple:
         self.bert_classifier = EBertClassifier.create(EBertClassifierSimple, self, self.config, **kwargs)
     else:
         self.bert_classifier = None
     self.tokenizer = BertTokenizer.from_pretrained(self.config.model_path)
     ELib.PASS()
Example #28
    def forward(self, input_batch, apply_softmax):
        b_output = self.bert_layer(input_batch['x'],
                                   attention_mask=input_batch['mask'],
                                   token_type_ids=input_batch['type'])
        last_hidden_states = b_output[0]
        output_pooled = b_output[1]

        self.output_vecs = np.copy(
            output_pooled.detach().to('cpu').numpy()).tolist()
        self.output_vecs_detail = np.copy(
            last_hidden_states.detach().to('cpu').numpy()).tolist()
        for cur_seq_ind, cur_seq_len in enumerate(input_batch['len']):
            self.output_vecs_detail[cur_seq_ind] = self.output_vecs_detail[
                cur_seq_ind][:cur_seq_len]
            ELib.PASS()

        output_pooled = self.last_dropout_layer(output_pooled)
        logits = self.last_layer(output_pooled)
        if apply_softmax:
            logits = F.softmax(logits, dim=1)
        return [(EVar.DefaultTask, logits)
                ]  # or [(input_batch['task_list'][0][0], logits)]
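A usage-style sketch of the same forward pass with a stock Hugging Face checkpoint. The model name 'bert-base-uncased', the ad-hoc classifier head, and return_dict=False (to get the tuple output indexed above) are assumptions; a recent transformers version and downloadable weights are required.

import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

# 'bert-base-uncased' and the fresh classifier head are assumptions for the demo.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
head = torch.nn.Linear(bert.config.hidden_size, 2)

batch = tokenizer(['a tiny example sentence'], return_tensors='pt', padding=True)
with torch.no_grad():
    # return_dict=False yields the (last_hidden_state, pooler_output) tuple indexed above.
    b_output = bert(batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    token_type_ids=batch['token_type_ids'],
                    return_dict=False)
    last_hidden_states, output_pooled = b_output[0], b_output[1]
    logits = head(output_pooled)
    probs = F.softmax(logits, dim=1)
print(probs.shape)  # torch.Size([1, 2])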
Example #29
 def add_module_to_optimizer(self, module):
     if isinstance(module, nn.ModuleList):
         print(colored('>>> cannot handle ModuleList to add to the optimizer! <<<', 'red'))
         sys.exit(1)
     lr = self.config.learn_rate
     weight_decay = self.config.weight_decay
     eps = self.config.adam_epsilon
     for cur_module_info in self.removed_modules:
         if cur_module_info[0] == module:
             lr = cur_module_info[1]
             weight_decay = cur_module_info[2]
             eps = cur_module_info[3]
             self.removed_modules.remove(cur_module_info)
             break
     self.optimizer.add_param_group(
         {
             'params': module.parameters(),
             'lr': lr,
             'weight_decay': weight_decay,
             'eps': eps
         }
     )
     ELib.PASS()
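A minimal sketch of the remove/re-add pattern used by remove_module_from_optimizer and add_module_to_optimizer, on a plain PyTorch optimizer; the two-layer model and the hyperparameter values are placeholders.

import torch
import torch.nn as nn

# Placeholder two-layer model and hyperparameters.
model = nn.ModuleDict({'body': nn.Linear(8, 4), 'head': nn.Linear(4, 2)})
optimizer = torch.optim.AdamW([
    {'params': model['body'].parameters(), 'lr': 2e-5, 'weight_decay': 0.01, 'eps': 1e-8},
    {'params': model['head'].parameters(), 'lr': 1e-3, 'weight_decay': 0.0, 'eps': 1e-8},
])

# Locate the group holding the head's parameters (compare by identity, not by value).
head_param = next(model['head'].parameters())
idx = next(i for i, g in enumerate(optimizer.param_groups)
           if any(p is head_param for p in g['params']))
saved = {k: optimizer.param_groups[idx][k] for k in ('lr', 'weight_decay', 'eps')}
del optimizer.param_groups[idx]     # the head is now excluded from optimization

# Later, restore the group with the remembered hyperparameters.
optimizer.add_param_group({'params': model['head'].parameters(), **saved})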
Example #30
 def setup_logs(self,
                dir_path,
                curve_names,
                add_hooks=False,
                hook_interval=10):
     # one summary_writer will be created for each name in curve_names
     # also if add_hooks=True one summary_writer will be also created
     # for each module in self.hooked_modules
     if os.path.exists(dir_path):
         shutil.rmtree(dir_path)
     for cur_name in curve_names:
         self.logs[cur_name] = SummaryWriter(
             os.path.join(dir_path, cur_name))
     if add_hooks:
         self.hook_interval = hook_interval
         self.hook_activated = True
         for name, module in self.hooked_modules.items():
             self.logs[name] = SummaryWriter(os.path.join(dir_path, name))
             self.hooksForward.append(
                 module.register_forward_hook(self.__hook_forward))
             self.hooksBackward.append(
                 module.register_backward_hook(self.__hook_backward))
     ELib.PASS()
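A small sketch of the hook-plus-SummaryWriter mechanism that setup_logs wires up, assuming the tensorboard package is installed; the layer, log directory, and step counter are illustrative.

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # requires the tensorboard package

layer = nn.Linear(4, 2)                    # illustrative module
writer = SummaryWriter('runs/hook_demo')   # illustrative log directory
step = 0

def log_forward(module, inputs, output):
    # Same idea as __hook_forward: histogram the activations and the weights.
    writer.add_histogram('activations', output.detach().cpu().numpy(), step)
    writer.add_histogram('weights', module.weight.detach().cpu().numpy(), step)

handle = layer.register_forward_hook(log_forward)
layer(torch.randn(3, 4))                   # triggers the hook once
handle.remove()                            # mirrors close_logs(): detach the hook
writer.close()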