예제 #1
0
 def setup_loss(self, train_config):
     """Build the loss objects for training from ``train_config['criterion']``.

     Sets the following attributes:
       * ``self.org_criterion``: a loss built by ``get_single_loss`` from
         ``criterion['org_term']['criterion']``. It is ``None`` when that entry
         is missing or empty, or when ``org_term`` is not a dict.
       * ``self.criterion``: the loss built by ``get_custom_loss`` from the
         whole criterion config.
       * ``self.uses_teacher_output``: always ``False`` in this variant.
       * ``self.extract_org_loss``: whatever ``get_func2extract_org_output``
         returns for ``criterion['func2extract_org_loss']``. ``None`` is passed
         when that key is absent.
     """
     criterion_cfg = train_config['criterion']
     org_term_cfg = criterion_cfg.get('org_term', dict())
     if isinstance(org_term_cfg, dict):
         org_loss_cfg = org_term_cfg.get('criterion', dict())
     else:
         # A non-dict org_term (e.g. an explicit None) disables the org loss.
         org_loss_cfg = None

     # Only build the org loss when a non-empty criterion config was supplied.
     if org_loss_cfg is not None and len(org_loss_cfg) != 0:
         self.org_criterion = get_single_loss(org_loss_cfg)
     else:
         self.org_criterion = None

     self.criterion = get_custom_loss(criterion_cfg)
     self.uses_teacher_output = False
     extractor_spec = criterion_cfg.get('func2extract_org_loss', None)
     self.extract_org_loss = get_func2extract_org_output(extractor_spec)
예제 #2
0
 def __init__(self, criterion_config):
     """Collect the weighted sub-loss terms described by ``criterion_config``.

     If ``criterion_config['sub_terms']`` is present and not ``None``, it is
     iterated as a mapping of term name to term config. Each term config
     supplies:
       * ``'criterion'``: passed to ``get_single_loss``.
       * ``'params'`` (optional): passed to ``get_single_loss``.
       * ``'factor'``: stored alongside the loss.

     The result is stored in ``self.term_dict`` as
     ``{name: (loss, factor)}``. It is an empty dict when there are no sub
     terms.

     Raises:
         KeyError: if a term config lacks ``'criterion'`` or ``'factor'``.

     NOTE(review): ``term_dict`` is a plain dict. If the base class is a
     ``torch.nn.Module``, the sub-losses are not registered as submodules;
     confirm that is intended.
     """
     super().__init__()
     terms = dict()
     sub_terms = criterion_config.get('sub_terms', None)
     if sub_terms is None:
         self.term_dict = terms
         return

     for term_name, term_cfg in sub_terms.items():
         # Arguments are evaluated left to right: 'criterion' first, then 'params'.
         term_loss = get_single_loss(term_cfg['criterion'], term_cfg.get('params', None))
         terms[term_name] = (term_loss, term_cfg['factor'])
     self.term_dict = terms
 def setup_loss(self, train_config):
     """Build the loss objects for training from ``train_config['criterion']``.

     Sets the following attributes:
       * ``self.org_criterion``: a loss built by ``get_single_loss`` from
         ``criterion['org_term']['criterion']``. It is ``None`` when that entry
         is missing or empty, or when ``org_term`` is not a dict.
       * ``self.criterion``: the loss built by ``get_custom_loss`` from the
         whole criterion config. It is logged at INFO level.
       * ``self.uses_teacher_output``: ``True`` only when ``org_criterion`` is
         an instance of one of the classes in ``ORG_LOSS_LIST``.
       * ``self.extract_org_loss``: whatever ``get_func2extract_org_output``
         returns for ``criterion['func2extract_org_loss']``. ``None`` is passed
         when that key is absent.
     """
     criterion_cfg = train_config['criterion']
     org_term_cfg = criterion_cfg.get('org_term', dict())
     org_loss_cfg = None
     if isinstance(org_term_cfg, dict):
         org_loss_cfg = org_term_cfg.get('criterion', dict())

     has_org_loss = org_loss_cfg is not None and len(org_loss_cfg) > 0
     self.org_criterion = get_single_loss(org_loss_cfg) if has_org_loss else None

     self.criterion = get_custom_loss(criterion_cfg)
     logger.info(self.criterion)

     # The None check short-circuits before ORG_LOSS_LIST is converted to a tuple.
     self.uses_teacher_output = (
         self.org_criterion is not None
         and isinstance(self.org_criterion, tuple(ORG_LOSS_LIST))
     )
     extractor_spec = criterion_cfg.get('func2extract_org_loss', None)
     self.extract_org_loss = get_func2extract_org_output(extractor_spec)