Example #1
 def test(self, test_bundle, return_output_vecs=False, weighted_instance_loss=False,
          print_perf=True, title=None, report_number_of_intervals=20, return_output_vecs_get_details=True):
     if len(test_bundle.task_list) > 1:
         print('only one task is allowed for testing')
         return None
     if len(test_bundle.tws) == 0:
         return list(), list(), list(), list()
     if title is None:
         title = ''
     else:
         title += ' '
     ## move the classifier to the configured device and put it in evaluation mode
     self.bert_classifier.to(self.config.device)
     self.bert_classifier.zero_grad()
     self.bert_classifier.eval()
     self.setup_objective(weighted_instance_loss)
     ## tokenize the test bundle and build sequential (non-shuffled) batches
     test_dt = EBertDataset(test_bundle, self.tokenizer, self.config.max_seq)
     batches = self.generate_batches([test_dt], self.config, False, False, 0, EInputListMode.sequential)
     result_vecs = list()
     result_vecs_detail = list()
     ## running state (loss, labels, logits) for the single test task
     tasks = {test_bundle.task_list[0]: ETaskState(test_bundle.task_list[0])}
     print(title + 'labeling ', end=' ', flush=True)
     with torch.no_grad():
         for ba_ind, cur_batch in enumerate(batches):
             ## forward pass only; accumulate the loss and predictions into the task state
             outcome = self.bert_classifier(cur_batch, False)
             self.__process_loss(outcome, cur_batch, tasks, False, weighted_instance_loss)
             if return_output_vecs:
                 result_vecs.extend(self.bert_classifier.output_vecs)
                 if self.bert_classifier.output_vecs_detail is not None and return_output_vecs_get_details:
                     result_vecs_detail.extend(self.bert_classifier.output_vecs_detail)
             if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
                 print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
             self.delete_batch_from_gpu(cur_batch, EInputListMode.sequential)
             del cur_batch, outcome
     print()
     ## average the loss over the evaluated instances and compute F1 / precision / recall
     task_out = tasks[test_bundle.task_list[0]]
     task_out.loss /= task_out.size
     perf = ELib.calculate_metrics(task_out.lbl_true, task_out.lbl_pred)
     if print_perf:
         print('Test Results> Loss: {:.3f} F1: {:.3f} Pre: {:.3f} Rec: {:.3f}'.format(
             task_out.loss, perf[0], perf[1], perf[2]) + '\t\t' + ELib.get_time())
     self.bert_classifier.cpu()
     return task_out.lbl_pred, task_out.logits, [result_vecs, result_vecs_detail], perf
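A minimal usage sketch for the test() method above, assuming a trained instance 'clf' of the surrounding classifier class and a single-task bundle 'bundle' (both names are hypothetical placeholders, not part of the original code):

 ## hypothetical call site; 'clf' and 'bundle' are placeholder names
 lbl_pred, logits, vec_pair, perf = clf.test(
     bundle,                      # bundle whose task_list contains exactly one task
     return_output_vecs=True,     # also collect the classifier's output vectors
     print_perf=True,             # print loss / F1 / precision / recall
     title='dev set')             # prefix shown in front of the progress output
 result_vecs, result_vecs_detail = vec_pair
 f1, precision, recall = perf[0], perf[1], perf[2]   # order taken from the print statement above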
Example #2
 def __train_one_epoch(self, train_dt_list, train_tasks, input_mode, weighted_instance_loss,
                       report_number_of_intervals, train_shuffle, train_drop_last, balance_batch_mode_list):
     batches = self.generate_batches(train_dt_list, self.config, train_shuffle, train_drop_last,
                                     self.current_train_epoch, input_mode, balance_batch_mode_list)
     ## reset the running metrics of every task before the epoch starts
     for cur_task in train_tasks.values():
         cur_task.reset()
     for ba_ind, cur_batch in enumerate(batches):
         self.bert_classifier.train_step += 1  # track the overall step number inside the classifier
         ## recompute the forward pass and the loss until the optimizer step is no longer delayed
         while True:
             outcome = self.bert_classifier(cur_batch, False)
             self.__process_loss(outcome, cur_batch, train_tasks, True, weighted_instance_loss)
             if not self.delay_optimizer:
                 break
         if ELib.progress_made(ba_ind, cur_batch['batch_count'], report_number_of_intervals):
             print(ELib.progress_percent(ba_ind, cur_batch['batch_count']), end=' ', flush=True)
         self.delete_batch_from_gpu(cur_batch, input_mode)
         del cur_batch, outcome
         ## if there are multiple models and their losses are heavy in terms of memory, you can call
         ## 'self.sync_obj.lock_loss_calculation.acquire()' inside 'self.custom_train_loss_func()'.
         ## This way the losses are calculated one at a time, and afterwards the models are re-synced.
         if self.sync_obj is not None and self.sync_obj.lock_loss_calculation.locked():
             ## wait for the other models to arrive
             if self.sync_obj.sync_counter == self.sync_obj.model_count:
                 self.sync_obj.reset()
             self.sync_obj.sync_counter += 1
             self.sync_obj.lock_loss_calculation.release()
             while self.sync_obj.sync_counter < self.sync_obj.model_count:
                 self.sleep()
         # pprint(vars(self))
         # ELib.PASS()
     ## if there are multiple models, only the first one prints the newline to avoid double printing
     if self.sync_obj is None or self.model_id == 0:
         print()
     ## calculate the metric averages of the epoch
     for cur_task in train_tasks.values():
         if cur_task.size > 0:
             cur_task.loss /= cur_task.size
             cur_task.f1 = ELib.calculate_f1(cur_task.lbl_true, cur_task.lbl_pred)
     ELib.PASS()
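The synchronization comments above only reference a few attributes of 'self.sync_obj' (lock_loss_calculation, sync_counter, model_count, reset()). A minimal sketch of what such a shared object could look like, assuming the counter starts at model_count so that the first model to arrive opens a new round; this is an illustration, not the original implementation:

 import threading

 ## hypothetical stand-in for the shared sync object used in __train_one_epoch;
 ## only the attributes referenced above are modeled
 class ETrainSync:
     def __init__(self, model_count):
         self.model_count = model_count                 # number of models training in parallel
         self.sync_counter = model_count                # "full" so that the first arrival triggers reset()
         self.lock_loss_calculation = threading.Lock()  # acquired inside the custom loss function

     def reset(self):
         self.sync_counter = 0                          # start a new synchronization round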