def __init__(self, dataset: str, repr: str, model: str, backwards: bool, classifier_type: str):
    """Set up a classifier runner for the given dataset/representation/model combination.

    Args:
        dataset: name of the dataset to train/evaluate on.
        repr: identifier of the token representation (NOTE: shadows the
            builtin ``repr``; kept for caller compatibility).
        model: base model identifier.
        backwards: whether the model processes sequences right-to-left.
        classifier_type: kind of classifier; 'level' yields 6 output
            predictions, any other kind yields 2.
    """
    # Resolve all filesystem paths for a fully-pretrained classifier first,
    # since the text field must be loaded from there before delegating up.
    file_system = FS.for_classifier(dataset, repr, model, PretrainingType.FULL, classifier_type)
    loaded_text_field = file_system.load_text_field()
    # 'level' classification predicts one of 6 classes; everything else is binary.
    prediction_count = 6 if classifier_type == 'level' else 2
    super().__init__(
        repr=repr,
        fs=file_system,
        text_field=loaded_text_field,
        config_class=ClassifierTrainingConfig,
        output_field=LEVEL_LABEL,
        n_predictions=prediction_count,
        backwards=backwards,
    )
def _load_pretrained_weights(fs, rnn_learner, pretraining, base_model) -> None:
    """Best-effort load of pretrained weights into the learner.

    Falls back to training from scratch (with a warning) if loading fails,
    so a missing/corrupt base model never aborts the run.
    """
    if pretraining == PretrainingType.FULL:
        try:
            logger.info(f'Trying to load base classifier: {base_model}')
            fs.load_base_model(rnn_learner)
            logger.info('Base classifier model is loaded.')
        except Exception as e:
            # Deliberately broad: any load failure degrades to from-scratch training.
            logger.warning(e)
            logger.warning('Base classifier model not loaded. Training from scratch')
    elif pretraining == PretrainingType.ONLY_ENCODER:
        try:
            logger.info(f'Trying to load pretrained LM: {base_model}')
            # TODO its a dirty hack. fix it
            fs.lm_cl_pretraining = True
            fs.load_pretrained_langmodel(rnn_learner)
            logger.info("Using pretrained LM")
        except Exception as e:
            logger.warning(e)
            logger.warning('Pretrained LM not loaded. Training from scratch')
    else:
        logger.info("No pretraining. Training classifier from scratch.")


def run_on_device(config: ClassifierConfig, force_rerun: bool) -> None:
    """Train and test a classifier on the current device.

    Builds the network architecture, optionally loads pretrained weights
    (full classifier or encoder-only LM), trains, and runs sample tests.

    Args:
        config: full classifier configuration (data, arch, training, testing).
        force_rerun: retrain even if an identical trained model already exists.

    Raises:
        ValueError: if exactly one of base_model / pretraining_type is set.
    """
    base_model = config.base_model
    pretraining = config.pretraining_type

    PrepConfig.assert_classification_config(config.data.repr)

    # base_model and pretraining must be set together (or both omitted).
    if bool(base_model) != bool(pretraining):
        raise ValueError('Base model and pretraining_type params must be both set or both unset!')

    fs = FS.for_classifier(config.data.dataset, config.data.repr,
                           base_model=base_model, pretraining=pretraining,
                           classification_type=config.classification_type)

    fs.create_path_to_model(config.data, config.training_config)
    attach_dataset_aware_handlers_to_loggers(fs.path_to_model, 'main.log')

    print_gpu_info()

    text_field = fs.load_text_field()

    rnn_learner = create_nn_architecture(fs, text_field, LEVEL_LABEL,
                                         config.data, config.arch,
                                         config.min_log_coverage_percent)
    logger.info(rnn_learner)

    same_model_exists = fs.best_model_exists(rnn_learner)
    if same_model_exists and not force_rerun:
        # Fixed: original message ran two sentences together without a space
        # and read "this parameters".
        logger.info(f'Model {fs.path_to_classification_model} already trained. Not rerunning training. '
                    f'To retrain the model with these parameters, specify --force-rerun flag')
        return
    elif same_model_exists:
        logger.info(f"Model {fs.path_to_classification_model} already trained. Forcing rerun.")

    _load_pretrained_weights(fs, rnn_learner, pretraining, base_model)

    config_manager.save_config(config.training_config, fs.path_to_model)

    train(fs, rnn_learner, config.training, config.metrics)

    model = rnn_learner.model
    to_test_mode(model)

    sample_test_runs_file = os.path.join(fs.path_to_model, 'test_runs.out')
    # 'level' classification predicts one of 6 classes; everything else is binary.
    n_predictions = 6 if config.classification_type == 'level' else 2
    show_tests(fs.test_path, model, text_field, sample_test_runs_file,
               config.data.backwards, n_predictions, config.testing.n_samples)
    logger.info("Classifier training finished successfully.")