def save_epoch_status(status_dict, config):
    """Log a one-line summary of the epoch-level training/validation status.

    The line contains the epoch/step counters, the epoch train and validation
    loss/accuracy, the best-epoch and best-step summaries and two elapsed-time
    values, and is handed to ``record.save_content``.

    Args:
        status_dict: Running training status. Must provide the ``"acc_str"``,
            ``"epoch_elapsed_time"`` and ``"so_far_elapsed_time"`` keys, plus
            whatever ``_get_epoch_str`` / ``_get_step_str`` / ``_get_values`` /
            ``_get_best_values_str`` read.
        config: Run configuration forwarded to the helpers and to ``record``.
    """
    epoch_label = _get_epoch_str(status_dict, config)
    step_label = _get_step_str(status_dict, config)
    train_loss, train_acc, val_loss, val_acc = _get_values("epoch", status_dict)
    val_suffix = status_dict["acc_str"]
    best_epoch_label = _get_best_values_str("epoch", status_dict)
    best_step_label = _get_best_values_str("step", status_dict)
    epoch_seconds = status_dict["epoch_elapsed_time"]
    total_seconds = status_dict["so_far_elapsed_time"]

    # Adjacent literals concatenate to exactly the single-line format string.
    template = (
        "[E:%s] [S:%s] "
        "[Train Loss:%.6f Acc:%.6f] "
        "[Val Loss:%.6f Acc:%.6f %s] "
        "[Best Epoch:%s] [Best Step:%s] "
        "[%.2fs %.1fs] Epoch status"
    )
    fields = (
        epoch_label, step_label,
        train_loss, train_acc,
        val_loss, val_acc, val_suffix,
        best_epoch_label, best_step_label,
        epoch_seconds, total_seconds,
    )
    record.save_content(config, template % fields)
def save_step_status(status_dict, config):
    """Log a one-line summary of the step-level training/validation status.

    Args:
        status_dict: Running training status. Must provide ``"acc_str"`` and
            ``"progress_str"``, plus whatever ``_get_epoch_str`` /
            ``_get_step_str`` / ``_get_values`` / ``_get_best_values_str`` read.
        config: Run configuration forwarded to the helpers and to ``record``.
    """
    epoch_label = _get_epoch_str(status_dict, config)
    step_label = _get_step_str(status_dict, config)
    train_loss, train_acc, val_loss, val_acc = _get_values("step", status_dict)
    val_suffix = status_dict["acc_str"]
    # Printed immediately after the train accuracy, with no separating space.
    progress = status_dict["progress_str"]
    best_epoch_label = _get_best_values_str("epoch", status_dict)
    best_step_label = _get_best_values_str("step", status_dict)

    template = (
        "[E:%s] [S:%s] "
        "[Train Loss:%.6f Acc:%.6f%s] "
        "[Val Loss:%.6f Acc:%.6f %s] "
        "[Best Epoch:%s] [Best Step:%s] Step status"
    )
    fields = (
        epoch_label, step_label,
        train_loss, train_acc, progress,
        val_loss, val_acc, val_suffix,
        best_epoch_label, best_step_label,
    )
    record.save_content(config, template % fields)
def train_the_model(self):
    """Log the model, snapshot its weights, and run the remaining epochs.

    Epochs run from ``status_dict["epoch"] + 1`` up to (but excluding)
    ``config["epochs"]``. Each one is delegated to ``self._run_a_epoch``.
    """
    net = self.graph_dict["model"]
    record.save_content(self.config, net)
    total_epochs = self.config["epochs"]
    # Seed the best-weights snapshot with the current weights. Deep-copied so
    # that later in-place parameter updates cannot alter the stored snapshot.
    self.status_dict["best_model_wts"] = copy.deepcopy(net.state_dict())
    # Resume after the epoch already recorded in the status dict.
    first_epoch = self.status_dict["epoch"] + 1
    for epoch_number in range(first_epoch, total_epochs):
        self._run_a_epoch(epoch_number)
def without_finetune():
    """Train, then test, the transfer-learning graph with fine-tuning disabled.

    NOTE: sets ``config["finetune"] = False`` on the module-level ``config``
    in place, so the change is visible to everything sharing that dict.
    """
    config["finetune"] = False
    record.save_content(config, "run the model with finetune=False")
    record.record_dict(config, config)

    dataset_info = ants_bees_data_set.get_dataset_info_dict(config)

    trainer = TrainTransferLearnGraph(dataset_info, config)
    trainer.train_the_model()

    tester = TestTransferLearnGraph(dataset_info, config)
    tester.test_the_model()
def _check_max_stop(self, value, stop_category): if self.early_stop: return True config = self.config step = self.status_dict[stop_category] if step >= value: record.save_content( config, " Stop With Max %s: %s" % (stop_category, step)) self.early_stop = True return True return False
def _check_early_stop(self, stop_category): if self.early_stop: return True config = self.config if config["early_stop_%s" % stop_category]: step = self.status_dict[stop_category] best_step = self.status_dict["best_%s" % stop_category] if step - best_step > config["early_stop_%s_limit" % stop_category]: record.save_content( config, "Early Stop With %s: %s" % (stop_category, step)) self.early_stop = True return True return False
def predict(self, input_line, rnn=None, n_predictions=3):
    """Classify ``input_line``, log the top categories, and return them.

    Args:
        input_line: Raw text line, converted to model input with
            ``name_dataset.line_to_tensor(input_line, config)``.
        rnn: Defaults to ``self.graph_dict["model"]``.
            NOTE(review): this value is resolved but never used; the forward
            pass goes through ``self.get_model_output``. Confirm whether it
            should be forwarded.
        n_predictions: How many top-scoring categories to report.

    Returns:
        List of ``[score, category_name]`` pairs in the order returned by
        ``topk`` (largest first).
    """
    if rnn is None:
        rnn = self.graph_dict["model"]
    all_categories = self.data_dict["all_categories"]
    config = self.config
    record.save_content(config, '\n> %s' % input_line)
    with torch.no_grad():
        output = self.get_model_output(
            name_dataset.line_to_tensor(input_line, config))
        record.save_content(config, output.numpy())
        record.save_content(config, all_categories)
        # Top-k along dim 1, largest first; row 0 is the single input line.
        topv, topi = output.topk(n_predictions, 1, True)
        predictions = []
        for value, category_index in zip(topv[0].tolist(), topi[0].tolist()):
            category = all_categories[category_index]
            record.save_content(config, '(%.2f) %s' % (value, category))
            predictions.append([value, category])
    # Previously the list was built and discarded; return it so callers can
    # use the result instead of only seeing it in the log.
    return predictions
def _create_output(self):
    """Write the model description and the config to the run's output.

    The write is skipped only when ``record.check_output_file`` reports an
    existing output AND ``config["train_load_check_point_file"]`` is truthy.
    Presumably this avoids re-writing the header for a resumed run; confirm.
    """
    cfg = self.config
    # `and` short-circuits exactly like the `not A or not B` form: the
    # checkpoint key is only read when an output file already exists.
    if record.check_output_file(cfg) and cfg["train_load_check_point_file"]:
        return
    record.save_content(cfg, self.graph_dict["model"])
    record.record_dict(cfg, cfg)
def samples(self, category, start_letters='ABC'):
    """Generate and log one sample from ``category`` per starting letter.

    For each character of ``start_letters`` a header line is logged, then the
    output of ``self.sample(category, letter)``.
    """
    for letter in start_letters:
        header = "%s start:%s" % (category, letter)
        record.save_content(self.config, header)
        generated = self.sample(category, letter)
        record.save_content(self.config, generated)
def _init_best_epoch_model(self):
    """Load the weights stored under ``"best_epoch_model_wts"`` into the model."""
    record.save_content(self.config, "best epoch model")
    net = self.graph_dict["model"]
    best_weights = self.status_dict["best_epoch_model_wts"]
    net.load_state_dict(best_weights)
def save_epoch_status(status_dict, config):
    """Log the running-status line for the end of an epoch.

    Passes the ``"epoch"`` category (it previously passed ``"step"``) so the
    metrics reported under the "Epoch Status" label are the epoch ones.
    NOTE(review): assumes the second argument of ``_get_running_status_str``
    selects the metric category, as ``_get_values("epoch", ...)`` does for
    epoch status elsewhere; confirm against that helper.

    Args:
        status_dict: Running training status consumed by the helper.
        config: Run configuration forwarded to the helper and to ``record``.
    """
    save_str = _get_running_status_str(status_dict, "epoch", config)
    record.save_content(config, "%s Epoch Status" % save_str)