Example #1
    def step_end(self, run_context):
        """step end and do evaluation"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % 100 == 0:
            callback = Accuracy()
            columns_list = [
                "input_ids", "input_mask", "segment_ids", "label_ids"
            ]
            self.network.set_train(False)
            for data in self.dataset.create_dict_iterator():
                input_data = []
                for i in columns_list:
                    input_data.append(Tensor(data[i]))
                input_ids, input_mask, token_type_id, label_ids = input_data
                logits = self.network(input_ids, token_type_id, input_mask)
                callback.update(logits[3], label_ids)
            acc = callback.acc_num / callback.total_num
            with open("./eval.log", "a+") as f:
                f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(
                    callback.acc_num, callback.total_num,
                    callback.acc_num / callback.total_num))
                f.write('\n')

            if acc > self.global_acc:
                self.global_acc = acc
                print("The best acc is {}".format(acc))
                eval_model_ckpt_file = "eval_model.ckpt"
                if os.path.exists(eval_model_ckpt_file):
                    os.remove(eval_model_ckpt_file)
                _exec_save_checkpoint(self.network, eval_model_ckpt_file)
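
The Accuracy object used above is defined elsewhere in the surrounding BERT evaluation scripts and is not shown here. A minimal sketch of what such an accumulator might look like, assuming logits of shape (batch, num_labels) and integer label ids (hypothetical, for illustration only):

import numpy as np

class Accuracy:
    """Hypothetical accuracy accumulator matching the acc_num/total_num usage above."""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        # Convert MindSpore tensors to numpy and take the argmax as the prediction.
        preds = np.argmax(logits.asnumpy(), axis=-1).reshape(-1)
        labels = labels.asnumpy().reshape(-1)
        self.acc_num += int(np.sum(preds == labels))
        self.total_num += labels.shape[0]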
Example #2
def test_exec_save_checkpoint():
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)

    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")

    load_checkpoint("new_ckpt.ckpt")
Example #3
    def step_end(self, run_context):
        """step end and save ckpt"""
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.save_ckpt_step == 0:
            saved_ckpt_num = cb_params.cur_step_num / self.save_ckpt_step
            if saved_ckpt_num > self.max_ckpt_num:
                oldest_ckpt_index = saved_ckpt_num - self.max_ckpt_num
                path = os.path.join(self.output_dir,
                                    "tiny_bert_{}_{}.ckpt".format(int(oldest_ckpt_index),
                                                                  self.save_ckpt_step))
                if os.path.exists(path):
                    os.remove(path)
            _exec_save_checkpoint(self.network,
                                  os.path.join(self.output_dir,
                                               "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
                                                                             self.save_ckpt_step)))
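
A hedged, standalone sketch of the same retention idea (a hypothetical helper, not part of the original script): it removes every "tiny_bert_<index>_<step>.ckpt" file that falls outside the most recent max_ckpt_num saves.

import os

def remove_stale_checkpoints(output_dir, step, saved_ckpt_num, max_ckpt_num):
    """Hypothetical helper: keep only the newest max_ckpt_num checkpoints."""
    oldest_kept = int(saved_ckpt_num) - max_ckpt_num + 1
    for index in range(1, oldest_kept):
        path = os.path.join(output_dir, "tiny_bert_{}_{}.ckpt".format(index, step))
        if os.path.exists(path):
            os.remove(path)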
Example #4
    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num -
                             1) % cb_params.batch_num + 1

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                               + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory,
                                                  self._prefix)
            # keep the number of checkpoint files within the configured maximum.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(
                        self._config.keep_checkpoint_per_n_minutes,
                        self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            tmp_ckpt_file_name_for_cur_process = str(
                os.getpid()) + "-" + 'parameters.ckpt'
            gen_file = os.path.join(_save_dir,
                                    tmp_ckpt_file_name_for_cur_process)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()

            _exec_save_checkpoint(cb_params.train_network, gen_file,
                                  self._config.integrated_save)

            if os.path.exists(gen_file):
                shutil.move(gen_file, cur_file)
            self._latest_ckpt_file_name = cur_file
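
The pid-based temporary file followed by shutil.move is a write-then-rename pattern: the checkpoint only appears under its final name once it has been written completely, so readers never see a half-written file. A generic sketch of the same idea, independent of MindSpore (write_fn is a hypothetical serialization callback):

import os
import shutil

def write_then_rename(directory, final_name, write_fn):
    """Write to a per-process temporary file, then move it to its final name."""
    tmp_path = os.path.join(directory, "{}-parameters.tmp".format(os.getpid()))
    write_fn(tmp_path)                 # serialize the checkpoint to the temporary path
    final_path = os.path.join(directory, final_name)
    shutil.move(tmp_path, final_path)  # atomic when source and target share a filesystem
    return final_path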
Example #5
    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num -
                             1) % cb_params.batch_num + 1

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                               + str(step_num_in_epoch) + ".ckpt"
            if os.getenv("MS_ROLE") == "MS_PSERVER":
                from mindspore.parallel._ps_utils import _get_ps_mode_rank
                cur_ckpoint_file = "PServer_" + str(
                    _get_ps_mode_rank()) + "_" + cur_ckpoint_file
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory,
                                                  self._prefix)
            # keep the number of checkpoint files within the configured maximum.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(
                        self._config.keep_checkpoint_per_n_minutes,
                        self._cur_time_for_keep)

            # generate the new checkpoint file.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()

            _exec_save_checkpoint(cb_params.train_network, cur_file,
                                  self._config.integrated_save,
                                  self._config.async_save)

            self._latest_ckpt_file_name = cur_file
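
Both _save_ckpt variants above live inside MindSpore's ModelCheckpoint callback; from user code the same retention behaviour is normally configured through CheckpointConfig rather than by calling _exec_save_checkpoint directly. A minimal usage sketch, assuming the public CheckpointConfig and ModelCheckpoint APIs (the prefix and directory values are arbitrary examples):

from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

# Save every 100 steps and keep at most 5 checkpoint files on disk.
config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./ckpt", config=config)
# The callback is then passed to Model.train, e.g.:
# model.train(epoch_size, dataset, callbacks=[ckpt_cb])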
Example #6
File: train.py  Project: xyg320/mindspore
def train():
    """Train GAT model."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir',
                        type=str,
                        default='./data/cora/cora_mr',
                        help='Data dir')
    parser.add_argument('--train_nodes_num',
                        type=int,
                        default=140,
                        help='Number of nodes for training')
    parser.add_argument('--eval_nodes_num',
                        type=int,
                        default=500,
                        help='Number of nodes for evaluation')
    parser.add_argument('--test_nodes_num',
                        type=int,
                        default=1000,
                        help='Number of nodes for test')
    args = parser.parse_args()
    if not os.path.exists("ckpts"):
        os.mkdir("ckpts")
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)
    # train parameters
    hid_units = GatConfig.hid_units
    n_heads = GatConfig.n_heads
    early_stopping = GatConfig.early_stopping
    lr = GatConfig.lr
    l2_coeff = GatConfig.l2_coeff
    num_epochs = GatConfig.num_epochs
    feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(
        args.data_dir, args.train_nodes_num, args.eval_nodes_num,
        args.test_nodes_num)
    feature_size = feature.shape[2]
    num_nodes = feature.shape[1]
    num_class = y_train.shape[2]

    gat_net = GAT(feature,
                  biases,
                  feature_size,
                  num_class,
                  num_nodes,
                  hid_units,
                  n_heads,
                  attn_drop=GatConfig.attn_dropout,
                  ftr_drop=GatConfig.feature_dropout)
    gat_net.add_flags_recursive(fp16=True)

    eval_net = LossAccuracyWrapper(gat_net, num_class, y_val, eval_mask,
                                   l2_coeff)

    train_net = TrainGAT(gat_net, num_class, y_train, train_mask, lr, l2_coeff)

    train_net.set_train(True)
    val_acc_max = 0.0
    val_loss_min = np.inf
    curr_step = 0
    for _epoch in range(num_epochs):
        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_acc = train_result[1].asnumpy()

        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_acc = eval_result[1].asnumpy()

        print(
            "Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}"
            .format(_epoch, train_loss, train_acc, eval_loss, eval_acc))
        if eval_acc >= val_acc_max or eval_loss < val_loss_min:
            if eval_acc >= val_acc_max and eval_loss < val_loss_min:
                val_acc_model = eval_acc
                val_loss_model = eval_loss
                _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
            val_acc_max = np.max((val_acc_max, eval_acc))
            val_loss_min = np.min((val_loss_min, eval_loss))
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == early_stopping:
                print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".
                      format(val_loss_min, val_acc_max))
                print(
                    "Early stop model validation loss: {}, accuracy{}".format(
                        val_loss_model, val_acc_model))
                break
    gat_net_test = GAT(feature,
                       biases,
                       feature_size,
                       num_class,
                       num_nodes,
                       hid_units,
                       n_heads,
                       attn_drop=0.0,
                       ftr_drop=0.0)
    load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
    gat_net_test.add_flags_recursive(fp16=True)

    test_net = LossAccuracyWrapper(gat_net_test, num_class, y_test, test_mask,
                                   l2_coeff)
    test_result = test_net()
    print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))