コード例 #1
0
def save_epoch_status(status_dict, config):
    """Record a one-line end-of-epoch summary via ``record.save_content``.

    The line contains the epoch/step labels, the "epoch"-category train and
    validation loss/accuracy, ``status_dict["acc_str"]``, the best-epoch and
    best-step summaries, two elapsed-time figures, and ends with
    "Epoch status".

    Args:
        status_dict: training status mapping. Read directly here:
            ``"acc_str"``, ``"epoch_elapsed_time"`` and
            ``"so_far_elapsed_time"``; everything else is obtained through
            the ``_get_*`` helpers.
        config: run configuration forwarded to the helpers and to
            ``record.save_content``.
    """
    template = (
        "[E:%s] [S:%s] "
        "[Train Loss:%.6f Acc:%.6f] "
        "[Val Loss:%.6f Acc:%.6f %s] "
        "[Best Epoch:%s] [Best Step:%s] "
        "[%.2fs %.1fs] Epoch status"
    )
    epoch_label = _get_epoch_str(status_dict, config)
    step_label = _get_step_str(status_dict, config)
    tr_loss, tr_acc, val_loss, val_acc = _get_values("epoch", status_dict)
    val_acc_note = status_dict["acc_str"]
    best_epoch = _get_best_values_str("epoch", status_dict)
    best_step = _get_best_values_str("step", status_dict)
    timings = (
        status_dict["epoch_elapsed_time"],
        status_dict["so_far_elapsed_time"],
    )
    fields = (
        epoch_label, step_label,
        tr_loss, tr_acc,
        val_loss, val_acc, val_acc_note,
        best_epoch, best_step,
    ) + timings
    record.save_content(config, template % fields)
コード例 #2
0
def save_step_status(status_dict, config):
    """Record a one-line per-step training summary via ``record.save_content``.

    The line contains the epoch/step labels, the "step"-category train
    loss/accuracy immediately followed by ``status_dict["progress_str"]``,
    the validation loss/accuracy with ``status_dict["acc_str"]``, the
    best-epoch and best-step summaries, and ends with "Step status".

    Args:
        status_dict: training status mapping. Read directly here:
            ``"acc_str"`` and ``"progress_str"``; the remaining fields come
            from the ``_get_*`` helpers.
        config: run configuration forwarded to the helpers and to
            ``record.save_content``.
    """
    template = (
        "[E:%s] [S:%s] "
        "[Train Loss:%.6f Acc:%.6f%s] "
        "[Val Loss:%.6f Acc:%.6f %s] "
        "[Best Epoch:%s] [Best Step:%s] Step status"
    )
    epoch_label = _get_epoch_str(status_dict, config)
    step_label = _get_step_str(status_dict, config)
    tr_loss, tr_acc, val_loss, val_acc = _get_values("step", status_dict)
    val_acc_note = status_dict["acc_str"]
    # Appended directly after the train accuracy (no separating space).
    progress_note = status_dict["progress_str"]
    best_epoch = _get_best_values_str("epoch", status_dict)
    best_step = _get_best_values_str("step", status_dict)
    fields = (
        epoch_label, step_label,
        tr_loss, tr_acc, progress_note,
        val_loss, val_acc, val_acc_note,
        best_epoch, best_step,
    )
    record.save_content(config, template % fields)
コード例 #3
0
    def train_the_model(self):
        """Run the training loop, continuing after ``status_dict["epoch"]``.

        Logs the model object through ``record.save_content``, stores a deep
        copy of its current ``state_dict()`` under
        ``status_dict["best_model_wts"]`` as the initial best weights, then
        calls ``self._run_a_epoch`` for every epoch from
        ``status_dict["epoch"] + 1`` up to, but not including,
        ``config["epochs"]``.
        """
        net = self.graph_dict["model"]
        record.save_content(self.config, net)
        last_epoch = self.config["epochs"]

        # Deep copy so the stored snapshot is independent of later updates.
        self.status_dict["best_model_wts"] = copy.deepcopy(net.state_dict())

        # NOTE(review): when status_dict["epoch"] is 0 this runs epochs
        # 1..epochs-1; confirm the intended initial value of the counter.
        first_epoch = self.status_dict["epoch"] + 1
        for epoch in range(first_epoch, last_epoch):
            self._run_a_epoch(epoch)
コード例 #4
0
ファイル: main.py プロジェクト: darr/transfer_learn
def without_finetune():
    """Train and then test the transfer-learning graphs with finetuning off.

    Sets ``config["finetune"] = False`` on the module-level ``config``, so
    the change persists for later users of that object. It then records a
    banner line and the config, builds the dataset info, and trains with
    ``TrainTransferLearnGraph``. Finally it evaluates with
    ``TestTransferLearnGraph`` on the same data.
    """
    config["finetune"] = False
    record.save_content(config, "run the model with finetune=False")
    record.record_dict(config, config)

    dataset_info = ants_bees_data_set.get_dataset_info_dict(config)
    TrainTransferLearnGraph(dataset_info, config).train_the_model()
    TestTransferLearnGraph(dataset_info, config).test_the_model()
コード例 #5
0
    def _check_max_stop(self, value, stop_category):
        if self.early_stop:
            return True

        config = self.config
        step = self.status_dict[stop_category]
        if step >= value:
            record.save_content(
                config, " Stop With Max %s: %s" % (stop_category, step))
            self.early_stop = True
            return True

        return False
コード例 #6
0
    def _check_early_stop(self, stop_category):
        if self.early_stop:
            return True
        config = self.config

        if config["early_stop_%s" % stop_category]:
            step = self.status_dict[stop_category]
            best_step = self.status_dict["best_%s" % stop_category]
            if step - best_step > config["early_stop_%s_limit" %
                                         stop_category]:
                record.save_content(
                    config, "Early Stop With %s: %s" % (stop_category, step))
                self.early_stop = True
                return True

        return False
コード例 #7
0
    def predict(self, input_line, rnn=None, n_predictions=3):
        """Log and return the top ``n_predictions`` categories for a line.

        The line is converted with ``name_dataset.line_to_tensor`` and fed
        to ``self.get_model_output`` under ``torch.no_grad()``. The method
        logs the input line, the raw output, the category list and each
        ranked prediction via ``record.save_content``.

        Args:
            input_line: text to classify.
            rnn: optional model. It defaults to ``self.graph_dict["model"]``.
                NOTE(review): this value is resolved but never passed to
                ``get_model_output``, so a caller-supplied model is currently
                ignored; confirm how ``get_model_output`` picks its model.
            n_predictions: how many top-scoring categories to report.

        Returns:
            list of ``[score, category_name]`` pairs, highest score first.
            The list used to be built and then discarded. Returning it is
            backward compatible for callers that ignore the result.
        """
        if rnn is None:
            rnn = self.graph_dict["model"]
        all_categories = self.data_dict["all_categories"]
        config = self.config
        record.save_content(config, '\n> %s' % input_line)
        with torch.no_grad():
            output = self.get_model_output(
                name_dataset.line_to_tensor(input_line, config))
            record.save_content(config, output.numpy())
            record.save_content(config, all_categories)
            # largest=True; topk returns the values sorted in descending order.
            topv, topi = output.topk(n_predictions, 1, True)
            predictions = []
            for score_t, index_t in zip(topv[0], topi[0]):
                value = score_t.item()
                category = all_categories[index_t.item()]
                record.save_content(config, '(%.2f) %s' % (value, category))
                predictions.append([value, category])
        return predictions
コード例 #8
0
 def _create_output(self):
     """Log the model and the config unless two conditions both hold.

     Logging is skipped only when ``record.check_output_file(self.config)``
     is truthy AND ``config["train_load_check_point_file"]`` is truthy.
     Otherwise the model object and the config dict are written through
     ``record``.
     """
     # Same short-circuit as the original (not A) or (not B): the config key
     # is only read when check_output_file(...) is truthy.
     skip_logging = (record.check_output_file(self.config)
                     and self.config["train_load_check_point_file"])
     if skip_logging:
         return
     net = self.graph_dict["model"]
     record.save_content(self.config, net)
     record.record_dict(self.config, self.config)
コード例 #9
0
 def samples(self, category, start_letters='ABC'):
     """Generate and log one sample per start letter for ``category``.

     For each character of ``start_letters``, a "<category> start:<letter>"
     header is logged. Then ``self.sample(category, letter)`` is called and
     its result is logged, both through ``record.save_content``.
     """
     for letter in start_letters:
         header = "%s start:%s" % (category, letter)
         record.save_content(self.config, header)
         generated = self.sample(category, letter)
         record.save_content(self.config, generated)
コード例 #10
0
 def _init_best_epoch_model(self):
     """Load the stored best-epoch weights into the current model.

     Logs "best epoch model" first. It then calls ``load_state_dict`` on
     ``self.graph_dict["model"]`` with
     ``self.status_dict["best_epoch_model_wts"]``.
     """
     record.save_content(self.config, "best epoch model")
     net = self.graph_dict["model"]
     best_weights = self.status_dict["best_epoch_model_wts"]
     net.load_state_dict(best_weights)
コード例 #11
0
def save_epoch_status(status_dict, config):
    """Record the running-status line followed by " Epoch Status".

    NOTE(review): the status text is built with category "step" even though
    this is the end-of-epoch report. Confirm this is intended and not meant
    to be "epoch".
    """
    running = _get_running_status_str(status_dict, "step", config)
    message = "%s Epoch Status" % running
    record.save_content(config, message)