Example #1
def eval_main(checkpoint):
    config = checkpoint['config']
    data_config = config['data']

    # Build the validation-time transform and the test data loader.
    tsf = _get_transform(config, 'val')
    data_manager = getattr(data_module, data_config['type'])(data_config)
    test_loader = data_manager.get_loader('test', tsf)

    # Reconstruct the model from the checkpoint metadata and restore its weights.
    m_name, sd, classes = _get_model_att(checkpoint)
    model = getattr(net_module, m_name)(classes, state_dict=sd)
    model.load_state_dict(checkpoint['state_dict'])

    # Instantiate the configured metrics for the number of classes.
    num_classes = len(classes)
    metrics = getattr(net_module, config['metrics'])(num_classes)

    # Run evaluation on the test set and report the results.
    evaluation = ClassificationEvaluator(test_loader, model)
    ret = evaluation.evaluate(metrics)
    print(ret)
    return ret
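A minimal sketch of how eval_main might be invoked, assuming the checkpoint was written with torch.save and contains the 'config' and 'state_dict' entries used above; the checkpoint path and the module holding eval_main are placeholders:

import torch

from eval import eval_main  # hypothetical module containing the function above

if __name__ == '__main__':
    # map_location='cpu' lets the sketch run on machines without a GPU.
    checkpoint = torch.load('checkpoints/model_best.pth', map_location='cpu')
    results = eval_main(checkpoint)
    print(results)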
Example #2
    def execute(self):
        def capitalize(s):
            # Upper-case only the first character; unlike str.capitalize(),
            # this keeps the rest of the name unchanged.
            return s[0].upper() + s[1:]

        # Resolve the scenario once and iterate over its steps.
        scenario = self.config.get_scenario()
        print(
            f"******* starting pipeline {self.config['pipeline']['cur_scenario']} ******"
        )
        for current_step in scenario:
            print(
                f" =================== Current step: {current_step} ============================="
            )
            self.config.set_current_pipeline_step(current_step)
            self.pipeline_config = self.config.pipeline_config

            logger.info('******** LOADING DATA *******')
            self.get_data_loader()

            # Training phase
            logger.info('******* TRAINING *******')
            print("get task type =========== ", self.config.get_task_type())
            if not 'disable' in self.pipeline_config['train'].keys() or \
                    self.pipeline_config['train']['disable'] is not True:
                train_phase = self.pipeline_config['train']
                train_module = __import__('train')
                converter_module = __import__('converter')
                converter_class = getattr(
                    converter_module,
                    capitalize(train_phase['converter']) + 'Converter')
                if self.config.get_task_type() == 'sequence':
                    trainer_class_name = 'SpanTrainer'
                else:
                    trainer_class_name = 'ClassificationTrainer'
                trainer_class = getattr(train_module, trainer_class_name)
                trainer = trainer_class(self.dataloader, converter_class())
                trainer.train()
            else:
                logger.info(">>> training disabled")
                # Even when training is skipped, the converter is still needed
                # by the evaluation phase below, so build it here as well.
                train_phase = self.pipeline_config['train']
                train_module = __import__('train')
                converter_module = __import__('converter')
                converter_class = getattr(
                    converter_module,
                    capitalize(train_phase['converter']) + 'Converter')

            # Evaluation phase
            logger.info('******* EVALUATION *******')
            if self.pipeline_config['eval'].get('disable') is not True:

                if self.config.get_task_type() == 'sequence':
                    #evaluator = SequenceEvaluator(self.dataloader)
                    evaluator = SpanEvaluator(self.dataloader)
                else:
                    evaluator = ClassificationEvaluator(self.dataloader)

                evaluator.evaluate(converter_class())

                # Create a score file for evaluation
                if self.config.get_task_type() != 'sequence':
                    from score import score_task1
                    pred_file = self.config.get_output_data_dir() + 'predict1.txt'
                    gold_file = self.config.get_output_data_dir() + 'true1.txt'
                    score_task1(predict_file=pred_file, true_file=gold_file)

                    # Run the scorer
                    sys.path.append(
                        os.path.realpath('SEMEVAL-2021-task6-corpus'))

                    from scorer.task1_3 import evaluate, validate_files  # evaluate(pred_fpath, gold_fpath, CLASSES)
                    # from format_checker.task1_3 import validate_files
                    CLASSES = read_labels_from_file(
                        self.pipeline_config['data']['labels'])

                    if validate_files(pred_file, gold_file, CLASSES):
                        logger.info('Prediction file format is correct')
                        macro_f1, micro_f1 = evaluate(pred_file, gold_file,
                                                      CLASSES)
                        logger.info("macro-F1={:.5f}\tmicro-F1={:.5f}".format(
                            macro_f1, micro_f1))
                    else:
                        print("Failed to validate prediction & gold files")

                else:
                    print("No scoring for sequence type")

            else:
                print("Evaluation is disabled")

            # Post-processing phase
            if 'postprocess' in self.pipeline_config and \
                    self.pipeline_config['postprocess'].get('disable') is not True:
                logger.info('******* POST-PROCESSING *******')
                postprocess_phase = self.pipeline_config['postprocess']
                postprocess_module = __import__('postprocess')
                postprocess_class = getattr(
                    postprocess_module,
                    postprocess_phase['processor']['class'])
                postprocessor = postprocess_class(
                    self.config.get_model_type(),
                    self.config.get_output_dir(),
                    # The processor config dict is unpacked as keyword arguments.
                    **postprocess_phase['processor']
                )
                postprocessor.execute()

            else:
                print('No post-processing defined')
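The class that owns execute() and its configuration object are not shown in this example; module-level names such as logger, sys, os, SpanEvaluator, ClassificationEvaluator and read_labels_from_file are assumed to be imported elsewhere in the source file. A rough sketch of how such a pipeline might be driven, with Pipeline and PipelineConfig as purely hypothetical names for that class and its config wrapper:

from config import PipelineConfig  # hypothetical wrapper exposing get_scenario(), get_task_type(), ...
from pipeline import Pipeline      # hypothetical class defining the execute() method above

if __name__ == '__main__':
    # The config file is expected to describe the scenario plus the train,
    # eval and postprocess phases read inside execute().
    config = PipelineConfig('pipeline.yaml')
    Pipeline(config).execute()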