def load_pretrained_langmodel(self, rnn_learner: RNN_Learner) -> None:
    if self.is_lang_model:
        raise AssertionError("This method can only be called for classifiers.")
    logger.debug(f"Copying {self.path_to_base_torch_encoder} "
                 f"to {self.path_to_new_encoder_location}")
    shutil.copy(self.path_to_base_torch_encoder, self.path_to_new_encoder_location)
    rnn_learner.load_encoder(LM_ENCODER_NAME)

def find_and_plot_lr(rnn_learner: RNN_Learner, fs: FS):
    logger.info("Looking for the best learning rate...")
    # TODO: we shouldn't pass the file argument; when looking for the
    # learning rate we should log to the console
    rnn_learner.lr_find(file=open(os.path.join(fs.path_to_model, 'training.log'), 'w'))
    rnn_learner.sched.plot(fs.path_to_lr_plot)
    logger.info(f"Plot is saved to {fs.path_to_lr_plot}")

def gen_text(learner: RNN_Learner, starting_words_list: List[str],
             how_many_to_gen: int) -> List[str]:
    text = []
    # Numericalize the seed tokens and run them through the model to get
    # a prediction for the next token.
    t = to_gpu(learner.text_field.numericalize([starting_words_list], -1))
    res, *_ = learner.model(t)
    for _ in range(how_many_to_gen):
        # Sample the next token id from the exponentiated model output.
        n = torch.multinomial(res[-1].exp(), 1)
        # n = n[1] if n.data[0] == 0 else n[0]
        text.append(learner.text_field.vocab.itos[n.data[0]])
        # Feed the sampled token back in to predict the following one.
        res, *_ = learner.model(n[0].unsqueeze(0))
    return text

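# A minimal usage sketch for gen_text (hypothetical: assumes `learner` is an
# already trained RNN_Learner whose text_field was built during preprocessing;
# the seed tokens below are made up for illustration):
def _gen_text_demo(learner: RNN_Learner) -> None:
    seed = ['public', 'static', 'void']
    generated = gen_text(learner, seed, how_many_to_gen=50)
    logger.info(' '.join(seed + generated))
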
def get_transformer_model(self, opt_fn, emb_sz, max_seq_len, **kwargs):
    m = get_transformer_language_model(self.n_tok, max_seq_len,
                                       self.trn_dl.target_length, emb_sz,
                                       pad_token=self.pad_idx, **kwargs)
    model = TransformerLanguageModel(to_gpu(m))
    return RNN_Learner(self, model, opt_fn=opt_fn)

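# Hypothetical construction sketch: in fastai-0.7-style code this method would
# live on a model-data object (called `md` below, which this snippet assumes
# exists); the optimizer and sizes are illustrative, not project defaults:
#
#     learner = md.get_transformer_model(partial(optim.Adam, betas=(0.8, 0.99)),
#                                        emb_sz=512, max_seq_len=256)
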
def train_and_save_model(rnn_learner: RNN_Learner, fs: FS, training: LMTraining,
                         metric_list: List[str], cache: Cache,
                         use_subword_aware_metrics: bool):
    only_validation = False
    n = training.cycle.n
    if training.cycle.n == 0:
        logger.info("Number of epochs specified is 0. Not training...")
        fs.save_best(rnn_learner)
        # Run a single pass over the validation set only.
        only_validation = True
        n = 1

    training_start_time = time()
    training_log_file = os.path.join(fs.path_to_model, 'training.log')
    logger.info(f"Starting training, check {training_log_file} for training progress")

    callbacks = []
    if training.early_stop:
        callbacks.append(EarlyStopping(rnn_learner,
                                       save_path=BEST_MODEL_NAME,
                                       best_loss_path=BEST_LOSS_FILENAME,
                                       best_acc_path=BEST_ACC_FILENAME,
                                       best_epoch_path=BEST_EPOCH_FILENAME,
                                       enc_path=ENCODER_NAME))

    validation_function = get_validation_function(cache, use_subword_aware_metrics,
                                                  rnn_learner.text_field)
    # Resolve metric names from the config to the corresponding functions
    # in the metrics module.
    vals, ep_vals = rnn_learner.fit(lrs=training.lr, n_cycle=n, wds=training.wds,
                                    cycle_len=training.cycle.len,
                                    cycle_mult=training.cycle.mult,
                                    metrics=[getattr(metrics, x) for x in metric_list],
                                    get_ep_vals=True,
                                    file=open(training_log_file, 'w'),
                                    callbacks=callbacks,
                                    valid_func=validation_function,
                                    only_validation=only_validation)

    training_time_mins = int(time() - training_start_time) // 60
    with open(os.path.join(fs.path_to_model, 'results.out'), 'w') as f:
        f.write(str(training_time_mins) + "\n")
        for _, epoch_vals in ep_vals.items():
            f.write(" ".join(map(str, epoch_vals)) + "\n")

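# A hypothetical driver for train_and_save_model (the argument values here are
# illustrative; the LMTraining, Cache and FS instances normally come from this
# project's config loading and setup code):
def _train_lm_demo(rnn_learner: RNN_Learner, fs: FS,
                   training: LMTraining, cache: Cache) -> None:
    # 'accuracy' is resolved to metrics.accuracy via getattr inside
    # train_and_save_model
    train_and_save_model(rnn_learner, fs, training, metric_list=['accuracy'],
                         cache=cache, use_subword_aware_metrics=False)
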
def train(fs: FS, rnn_learner: RNN_Learner, training: ClassifierTraining,
          metric_list: List[str]):
    training_log_file = os.path.join(fs.path_to_model, 'training.log')
    if not training.stages:
        logger.warning("No stages specified in the config")
        return
    logger.info(f"Starting training, check {training_log_file} for training progress")
    for idx, stage in enumerate(training.stages):
        training_start_time = time()
        logger.info(f'----- Running stage {idx}')
        cycle = stage.cycle
        only_validation = False
        n = cycle.n
        if cycle.n == 0 or cycle.len == 0:
            logger.warning("Number of epochs specified at this stage is 0. Not training...")
            only_validation = True
            n = 1
        callbacks = []
        if training.early_stop:
            # Suffix the checkpoint files with the stage index; the final
            # stage keeps the plain file names.
            name_suffix = f'.{idx}' if idx < len(training.stages) - 1 else ''
            callbacks.append(EarlyStopping(rnn_learner,
                                           save_path=BEST_MODEL_NAME + name_suffix,
                                           best_loss_path=BEST_LOSS_FILENAME + name_suffix,
                                           best_acc_path=BEST_ACC_FILENAME + name_suffix,
                                           best_epoch_path=BEST_EPOCH_FILENAME + name_suffix))
        rnn_learner.freeze_to(stage.freeze_to)
        # Discriminative learning rates: each layer group's rate is the base
        # rate scaled down by factor ** multiplier.
        lrs = training.lrs
        lr_list = [lrs.base_lr / lrs.factor ** m for m in lrs.multipliers]
        logger.debug(f'Using the following learning rates: {lr_list}')
        vals, ep_vals = rnn_learner.fit(lrs=lr_list,
                                        metrics=[getattr(metrics, x) for x in metric_list],
                                        wds=training.wds,
                                        cycle_len=cycle.len, n_cycle=n,
                                        cycle_mult=cycle.mult,
                                        callbacks=callbacks,
                                        get_ep_vals=True,
                                        file=open(training_log_file, 'w+'),
                                        only_validation=only_validation)
        training_time_mins = int(time() - training_start_time) // 60
        with open(os.path.join(fs.path_to_model, 'results.out'), 'a+') as f:
            f.write(str(training_time_mins) + "\n")
            for _, epoch_vals in ep_vals.items():
                f.write(" ".join(map(str, epoch_vals)) + "\n")

    # logger.info(f'Current accuracy is ...')
    # logger.info(f' ... {accuracy_gen(*rnn_learner.predict_with_targs())}')
    # rnn_learner.sched.plot_loss()
    return rnn_learner

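# Worked example of the discriminative learning-rate schedule computed in
# train() above. The numbers are hypothetical, not project defaults: with
# base_lr=1e-3, factor=2.6 and multipliers=(2, 1, 0), the per-layer-group
# rates are [1e-3 / 2.6**2, 1e-3 / 2.6, 1e-3] ≈ [1.5e-4, 3.8e-4, 1e-3],
# so earlier (more general) layer groups get smaller updates.
def _lr_schedule_demo(base_lr=1e-3, factor=2.6, multipliers=(2, 1, 0)):
    # mirrors: lr_list = [lrs.base_lr / lrs.factor ** m for m in lrs.multipliers]
    return [base_lr / factor ** m for m in multipliers]
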
def load_base_model(self, rnn_learner: RNN_Learner) -> None:
    logger.debug(f"Copying from {self.path_to_base_torch_model} "
                 f"to {self.path_to_new_torch_model_location}")
    shutil.copy(self.path_to_base_torch_model, self.path_to_new_torch_model_location)
    rnn_learner.load(BEST_BASE_MODEL_NAME)

def load_best(self, learner: RNN_Learner) -> bool:
    try:
        learner.load(BEST_MODEL_NAME)
        return True
    except FileNotFoundError:
        return False