예제 #1
0
def get_training_info_dict(filepath):
    """Load a pickled training-info file and return a copy of its stored data.

    Args:
        filepath: Path to the file. A ``.pkl`` extension is appended when the
            path does not already end with it.

    Returns:
        ``trainer.PklFile(filepath).store.copy()`` -- a copy of the file's
        ``store`` attribute.

    Raises:
        AttributeError: If the (extension-normalized) file does not exist.
            The type is kept as ``AttributeError`` for backward compatibility
            with existing callers; the message now names the missing path.
    """
    if not filepath.endswith('.pkl'):
        filepath += '.pkl'
    if not os.path.exists(filepath):
        print("Could not open file: ", filepath)
        # Carry the path in the exception so it is not lost when stdout is not visible.
        raise AttributeError(f"Could not open file: {filepath}")
    return trainer.PklFile(filepath).store.copy()
예제 #2
0
    def run_training(self,
                     training_output_path,
                     summaries_path,
                     verbose=False):
        """
        Run training on the data loaded and prepared in the constructor, using
        the training params specified in the constructor.

        Args:
            training_output_path: Prefix to which the generated output filename
                is appended by plain string concatenation (so it should end
                with a path separator).
            summaries_path: Forwarded to ``self.get_filename`` to build the
                output filename.
            verbose: If True, print the model summary and the training params,
                and pass ``verbose=1`` to the trainer (``0`` otherwise).

        Side effects:
            Sets ``self.output_filename``, ``self.training_output_path``,
            ``self.start_timestamp`` and ``self.end_timestamp``, and prints
            progress information to stdout.
        """
        self.output_filename = self.get_filename(summaries_path=summaries_path)
        # Plain concatenation (not os.path.join) to keep the existing path contract.
        self.training_output_path = training_output_path + self.output_filename

        print("\n\nTraining the model")
        print("Filename: ", self.training_output_path)
        print("Number of training samples: ",
              len(self.train_data_normalized.data))
        print("Number of validation samples: ",
              len(self.validation_data_normalized.data))

        if verbose:
            self.model.summary()
            print("\nTraining params:")
            for arg, value in self.training_params.items():
                # Print the fields themselves, not a tuple repr like "('lr', ':', 0.01)".
                print(arg, ":", value)

        self.start_timestamp = datetime.datetime.now()

        # Inputs are also passed as targets (y == x); presumably an
        # autoencoder-style setup -- confirm with the model definition.
        trainer.Trainer(self.training_output_path, verbose=verbose).train(
            x_train=self.train_data_normalized.data,
            x_test=self.validation_data_normalized.data,
            y_train=self.train_data_normalized.data,
            y_test=self.validation_data_normalized.data,
            model=self.model,
            force=True,
            use_callbacks=True,
            verbose=int(verbose),
            **self.training_params,
            output_path=self.training_output_path)

        self.end_timestamp = datetime.datetime.now()
        elapsed = self.end_timestamp - self.start_timestamp
        # Report a plain number of seconds so the " s" unit label is accurate
        # (a timedelta would print as "H:MM:SS.ffffff").
        print("Training executed in: ", elapsed.total_seconds(), " s")
예제 #3
0
    args = parser.parse_args()

    # Logger
    FORMAT = '%(asctime)-15s %(message)s'
    logging.basicConfig(format=FORMAT, level=args.log_level)
    logger = logging.getLogger('global_logger')

    # Load the config. Only YAML parsing is guarded, so errors raised by the
    # pipeline below are not routed through the config-error handler. The
    # config file is also closed before the long-running pipeline starts.
    try:
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
    except yaml.YAMLError as err:
        logger.warning(f'Config file err: {err}')
    else:
        # Main

        # Preprocessing
        preprocessor = Preprocessor(config=config['preprocessing'],
                                    logger=logger)
        preprocessor.generate_data_loaders()

        # Training
        trainer = Trainer(config=config['training'],
                          logger=logger,
                          preprocessor=preprocessor)
        trainer.kfold_training()

        # Predicting
        predictor = Predictor(config=config['predict'],
                              logger=logger,
                              preprocessor=preprocessor)
        predictor.predict()
        predictor.save_result()
예제 #4
0
from module import Preprocessor, Trainer, Predictor

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process commandline')
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--log_level', type=str, default="INFO")
    args = parser.parse_args()

    FORMAT = '%(asctime)-15s %(message)s'
    logging.basicConfig(format=FORMAT, level=args.log_level)
    logger = logging.getLogger('global_logger')

    # Only YAML parsing is guarded; the config file is closed before the
    # pipeline runs, and pipeline errors are not swallowed by this handler.
    try:
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
    except yaml.YAMLError as err:
        logger.warning('Config file err: {}'.format(err))
    else:
        preprocessor = Preprocessor(config['preprocessing'], logger)
        (data_x, data_y, train_x, train_y,
         validate_x, validate_y, test_x) = preprocessor.process()
        trainer = Trainer(config['training'], logger, preprocessor.classes)
        # NOTE(review): the return value of this fit is discarded and the
        # validation that would use it is disabled below; kept so behavior
        # is unchanged.
        trainer.fit(train_x, train_y)
        # TODO: validation raises "ValueError: Classification metrics can't
        # handle a mix of multilabel-indicator and binary target", so it is
        # commented out for now to keep the program working.
        #accuracy, cls_report = trainer.validate(validate_x, validate_y)
        #logger.info("accuracy:{}".format(accuracy))
        #logger.info("\n{}\n".format(cls_report))
        # Refit on data_x/data_y (presumably the full dataset) for prediction.
        model = trainer.fit(data_x, data_y)
        predictor = Predictor(config['predict'], logger, model)
        probs = predictor.predict_prob(test_x)
        predictor.save_result(preprocessor.test_ids, probs)
예제 #5
0
    # Ensure the log folder exists. os.makedirs also creates missing parent
    # folders, and exist_ok=True avoids the exists()/mkdir() race.
    os.makedirs(log_folder, exist_ok=True)

    # os.path.join gives the same path as concatenation when log_folder ends
    # with a separator, and still lands inside log_folder when it doesn't.
    fileHandler = logging.FileHandler(os.path.join(log_folder, filename), 'w', 'utf-8')
    fileHandler.setFormatter(formatter)
    logger.addHandler(fileHandler)

    # Only YAML parsing is guarded; the config file is closed before the
    # long-running pipeline starts, and pipeline errors are no longer routed
    # through the config-error handler.
    try:
        with open(args.config, 'r') as config_file:
            print('Preprocessing...')
            config = yaml.safe_load(config_file)
    except yaml.YAMLError as err:
        logger.warning('Config file error: {}'.format(err))
    else:
        preprocessor = Preprocessor(config['preprocessing'], logger)
        _, _, train_x, train_y, validate_x, validate_y, test_x = preprocessor.process()

        # NOTE: 'naivebayse' (sic) is matched against the config value, so the
        # spelling must stay in sync with the config files. Every other model
        # gets vocab_size from the size of the preprocessor's word index.
        if config['training']['model_name'] != 'naivebayse':
            config['training']['vocab_size'] = len(preprocessor.word2ind)

        print('Training...')
        # Pass the embedding matrix only when the preprocessing config enables it.
        pretrained_embedding = preprocessor.embedding_matrix if config['preprocessing'].get('pretrained_embedding', None) else None
        trainer = Trainer(config['training'], logger, preprocessor.classes, pretrained_embedding)
        model, accuracy, cls_report, history = trainer.fit_and_validate(train_x, train_y, validate_x, validate_y)
        logger.info("accuracy:{}".format(accuracy))
        logger.info("\n{}\n".format(cls_report))

        print('Predicting...')
        predictor = Predictor(config['predict'], logger, model)
        probs = predictor.predict_prob(test_x)
        predictor.save_result(preprocessor.test_ids, probs)
    # Optimizer built from the CLI args for this model.
    optimizer = create_optimizer(args, model)

    # One (key, class count) pair per target; the keys "class_1"/"class_2"
    # are shared by the criterion and metric dicts handed to the Trainer.
    class_counts = (("class_1", num_class_1), ("class_2", num_class_2))

    # Loss functions, one per key (built in the listed order).
    criterions = {
        key: create_criterions(args, n_classes, device)
        for key, n_classes in class_counts
    }

    # Metric sets, one per key.
    metrics = {key: create_metrics(args) for key, _ in class_counts}

    # Wire everything into the trainer; no LR scheduler is used here.
    trainer = Trainer(
        model=model,
        optim=optimizer,
        criterions=criterions,
        metric=metrics,
        scheduler=None,
        train_dl=data["train_dataloader"],
        val_dl=data["valid_dataloader"],
        writer=writer,
        save_dir=args.model_dir,
        device=device,
    )
    trainer.fit(args.epochs)