Beispiel #1
0
    def load_pretrained_langmodel(self, rnn_learner: RNN_Learner) -> None:
        """Put the pretrained LM encoder in place and load it into a classifier.

        The encoder file at ``self.path_to_base_torch_encoder`` is copied to
        ``self.path_to_new_encoder_location`` and then loaded into
        ``rnn_learner`` via ``load_encoder(LM_ENCODER_NAME)``.

        :param rnn_learner: the classifier learner that receives the encoder.
        :raises AssertionError: if ``self.is_lang_model`` is true, i.e. this
            object is configured as a language model rather than a classifier.
        """
        if not self.is_lang_model:
            source = self.path_to_base_torch_encoder
            destination = self.path_to_new_encoder_location
            logger.debug(f"Copying {source} to {destination}")
            shutil.copy(source, destination)
            rnn_learner.load_encoder(LM_ENCODER_NAME)
            return
        raise AssertionError("This method can only be called for classifiers.")
Beispiel #2
0
def find_and_plot_lr(rnn_learner: RNN_Learner, fs: FS):
    """Run the learning-rate finder on ``rnn_learner`` and save its plot.

    The finder's progress output goes to ``training.log`` inside
    ``fs.path_to_model``. That file is opened in write mode, so any previous
    content is truncated. The resulting schedule plot is saved to
    ``fs.path_to_lr_plot``.

    :param rnn_learner: the learner whose learning rate is being searched.
    :param fs: provides ``path_to_model`` and ``path_to_lr_plot``.
    """
    logger.info("Looking for the best learning rate...")
    # TODO we shouldn't pass a file argument; when looking for the
    # learning rate we should log to the console.
    lr_log_path = os.path.join(fs.path_to_model, 'training.log')
    # Scope the handle so it is flushed and closed as soon as the finder
    # returns (or raises), instead of lingering until garbage collection.
    with open(lr_log_path, 'w') as log_file:
        rnn_learner.lr_find(file=log_file)

    rnn_learner.sched.plot(fs.path_to_lr_plot)
    logger.info(f"Plot is saved to {fs.path_to_lr_plot}")
Beispiel #3
0
def gen_text(learner: RNN_Learner, starting_words_list: List[str], how_many_to_gen: int) -> List[str]:
    """Sample ``how_many_to_gen`` tokens from ``learner``'s model.

    The model is first run on ``starting_words_list`` as a prompt. Each step
    then draws one token id with ``torch.multinomial`` from
    ``output[-1].exp()`` and feeds that id back into the model.

    NOTE(review): ``exp()`` suggests the model emits log-probabilities and
    that the model carries hidden state between calls. Confirm both against
    the model definition.

    :return: the sampled tokens as vocabulary strings; the prompt is not included.
    """
    field = learner.text_field
    itos = field.vocab.itos
    prompt = to_gpu(field.numericalize([starting_words_list], -1))
    output, *_ = learner.model(prompt)

    generated = []
    for _ in range(how_many_to_gen):
        sampled = torch.multinomial(output[-1].exp(), 1)
        generated.append(itos[sampled.data[0]])
        # Feed the freshly sampled token back in as the next input step.
        output, *_ = learner.model(sampled[0].unsqueeze(0))
    return generated
Beispiel #4
0
 def get_transformer_model(self, opt_fn, emb_sz, max_seq_len, **kwargs):
     """Build a transformer language model and wrap it in an ``RNN_Learner``.

     The vocabulary size (``self.n_tok``), the target length
     (``self.trn_dl.target_length``) and the padding index (``self.pad_idx``)
     come from this object. Extra keyword arguments are forwarded unchanged
     to ``get_transformer_language_model``.

     :param opt_fn: optimizer factory handed to ``RNN_Learner``.
     :param emb_sz: embedding size for the transformer.
     :param max_seq_len: maximum sequence length for the transformer.
     :return: an ``RNN_Learner`` around the GPU-placed, wrapped model.
     """
     vocab_size = self.n_tok
     target_len = self.trn_dl.target_length
     core = get_transformer_language_model(
         vocab_size, max_seq_len, target_len, emb_sz,
         pad_token=self.pad_idx, **kwargs)
     wrapped = TransformerLanguageModel(to_gpu(core))
     return RNN_Learner(self, wrapped, opt_fn=opt_fn)
Beispiel #5
0
def train_and_save_model(rnn_learner: RNN_Learner, fs: FS,
                         training: LMTraining, metric_list: List[str],
                         cache: Cache, use_subword_aware_metrics: bool):
    """Train a language model and record the per-epoch results.

    If ``training.cycle.n`` is 0, the current model is saved with
    ``fs.save_best`` and ``fit`` runs a single validation-only cycle
    instead of training. Training progress is written to ``training.log``
    in ``fs.path_to_model``. ``results.out`` in the same directory is then
    overwritten: its first line is the elapsed time in whole minutes, and
    each following line is one epoch's values separated by spaces.

    :param metric_list: names of attributes of the ``metrics`` module to
        evaluate during fitting.
    :param cache: forwarded to ``get_validation_function``.
    :param use_subword_aware_metrics: forwarded to ``get_validation_function``.
    """
    only_validation = False
    n = training.cycle.n
    if n == 0:
        logger.info("Number of epochs specified is 0. Not training...")
        fs.save_best(rnn_learner)
        # Run one validation-only cycle instead of training.
        only_validation = True
        n = 1

    training_start_time = time()
    training_log_file = os.path.join(fs.path_to_model, 'training.log')
    logger.info(
        f"Starting training, check {training_log_file} for training progress")
    callbacks = []

    if training.early_stop:
        callbacks.append(
            EarlyStopping(rnn_learner,
                          save_path=BEST_MODEL_NAME,
                          best_loss_path=BEST_LOSS_FILENAME,
                          best_acc_path=BEST_ACC_FILENAME,
                          best_epoch_path=BEST_EPOCH_FILENAME,
                          enc_path=ENCODER_NAME))

    validation_function = get_validation_function(cache,
                                                  use_subword_aware_metrics,
                                                  rnn_learner.text_field)
    metric_fns = [getattr(metrics, name) for name in metric_list]
    # Scope the log handle to the fit call so it is flushed and closed
    # deterministically, even if fit raises.
    with open(training_log_file, 'w') as log_file:
        _, ep_vals = rnn_learner.fit(lrs=training.lr,
                                     n_cycle=n,
                                     wds=training.wds,
                                     cycle_len=training.cycle.len,
                                     cycle_mult=training.cycle.mult,
                                     metrics=metric_fns,
                                     get_ep_vals=True,
                                     file=log_file,
                                     callbacks=callbacks,
                                     valid_func=validation_function,
                                     only_validation=only_validation)
    training_time_mins = int(time() - training_start_time) // 60
    with open(os.path.join(fs.path_to_model, 'results.out'), 'w') as f:
        f.write(str(training_time_mins) + "\n")
        for epoch_vals in ep_vals.values():
            f.write(" ".join(map(str, epoch_vals)) + "\n")
Beispiel #6
0
def train(fs: FS, rnn_learner: RNN_Learner, training: ClassifierTraining,
          metric_list: List[str]):
    """Train a classifier through the stages listed in ``training.stages``.

    For each stage, the learner is frozen up to ``stage.freeze_to`` and
    ``rnn_learner.fit`` runs with that stage's cycle settings. A stage whose
    cycle has ``n == 0`` or ``len == 0`` runs a single validation-only cycle
    instead. Progress from all stages is written to ``training.log`` in
    ``fs.path_to_model``. After each stage, its wall time in whole minutes
    and its per-epoch values are appended to ``results.out`` in the same
    directory.

    With early stopping enabled, checkpoint file names of every stage
    except the last get a ``.<stage index>`` suffix.

    :param metric_list: names of attributes of the ``metrics`` module to
        evaluate during fitting.
    :return: ``rnn_learner``, or ``None`` when no stages are configured.
    """
    training_log_file = os.path.join(fs.path_to_model, 'training.log')
    if not training.stages:
        logger.warning("No stages specified in the config")
        return
    logger.info(
        f"Starting training, check {training_log_file} for training progress")

    # Learning rates and metric functions are the same for every stage.
    lrs = training.lrs
    lr_list = [lrs.base_lr / lrs.factor**m for m in lrs.multipliers]
    metric_fns = [getattr(metrics, name) for name in metric_list]
    last_idx = len(training.stages) - 1

    # Open the log once for all stages. Reopening it in write mode per stage
    # would truncate earlier stages' output and leave handles unclosed.
    with open(training_log_file, 'w') as log_file:
        for idx, stage in enumerate(training.stages):
            training_start_time = time()
            logger.info(f'----- Running stage {idx}')
            cycle = stage.cycle
            only_validation = False
            n = cycle.n
            if cycle.n == 0 or cycle.len == 0:
                logger.warning(
                    "Number of epochs specified at this stage is 0. Not training..."
                )
                only_validation = True
                n = 1

            callbacks = []

            if training.early_stop:
                # Intermediate stages get distinct checkpoint names so the
                # final stage alone owns the unsuffixed "best" files.
                name_suffix = f'.{idx}' if idx < last_idx else ''
                callbacks.append(
                    EarlyStopping(
                        rnn_learner,
                        save_path=BEST_MODEL_NAME + name_suffix,
                        best_loss_path=BEST_LOSS_FILENAME + name_suffix,
                        best_acc_path=BEST_ACC_FILENAME + name_suffix,
                        best_epoch_path=BEST_EPOCH_FILENAME + name_suffix,
                    ))

            rnn_learner.freeze_to(stage.freeze_to)
            logger.debug(f'Using the following learning rates: {lr_list}')
            _, ep_vals = rnn_learner.fit(lrs=lr_list,
                                         metrics=metric_fns,
                                         wds=training.wds,
                                         cycle_len=cycle.len,
                                         n_cycle=n,
                                         cycle_mult=cycle.mult,
                                         callbacks=callbacks,
                                         get_ep_vals=True,
                                         file=log_file,
                                         only_validation=only_validation)
            training_time_mins = int(time() - training_start_time) // 60
            with open(os.path.join(fs.path_to_model, 'results.out'), 'a+') as f:
                f.write(str(training_time_mins) + "\n")
                for epoch_vals in ep_vals.values():
                    f.write(" ".join(map(str, epoch_vals)) + "\n")

    return rnn_learner
Beispiel #7
0
 def load_base_model(self, rnn_learner: RNN_Learner) -> None:
     """Copy the base model weights into place and load them into ``rnn_learner``.

     The file at ``self.path_to_base_torch_model`` is copied to
     ``self.path_to_new_torch_model_location``. It is then loaded through
     ``rnn_learner.load(BEST_BASE_MODEL_NAME)``.
     """
     source = self.path_to_base_torch_model
     destination = self.path_to_new_torch_model_location
     logger.debug(f"Copying from {source} to {destination}")
     shutil.copy(source, destination)
     rnn_learner.load(BEST_BASE_MODEL_NAME)
Beispiel #8
0
 def load_best(self, learner: RNN_Learner) -> bool:
     """Try to load the ``BEST_MODEL_NAME`` checkpoint into ``learner``.

     :return: ``True`` if the load succeeded, ``False`` if ``learner.load``
         raised ``FileNotFoundError``. Any other exception propagates.
     """
     try:
         learner.load(BEST_MODEL_NAME)
     except FileNotFoundError:
         return False
     else:
         return True