def step_end(self, run_context):
    """Step-end callback: every 100 steps, run evaluation over self.dataset.

    Computes accuracy with an Accuracy metric, appends the result to
    ./eval.log, and — when the accuracy beats self.global_acc — saves the
    network to "eval_model.ckpt" (replacing any previous best checkpoint).

    Args:
        run_context: MindSpore RunContext; original_args() supplies
            cur_step_num.
    """
    cb_params = run_context.original_args()
    # Guard clause: only evaluate every 100 steps.
    if cb_params.cur_step_num % 100 != 0:
        return
    callback = Accuracy()
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    # Eval mode is loop-invariant — set it once, not once per batch.
    self.network.set_train(False)
    for data in self.dataset.create_dict_iterator():
        input_ids, input_mask, token_type_id, label_ids = (
            Tensor(data[name]) for name in columns_list)
        logits = self.network(input_ids, token_type_id, input_mask)
        # logits[3] is the output slot the metric consumes
        # (presumably the classification logits — TODO confirm).
        callback.update(logits[3], label_ids)
    # Compute the ratio once and reuse it for both logging and comparison.
    acc = callback.acc_num / callback.total_num
    with open("./eval.log", "a+") as f:
        f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(
            callback.acc_num, callback.total_num, acc))
        f.write('\n')
    if acc > self.global_acc:
        self.global_acc = acc
        print("The best acc is {}".format(acc))
        eval_model_ckpt_file = "eval_model.ckpt"
        # Remove the stale best checkpoint before writing the new one.
        if os.path.exists(eval_model_ckpt_file):
            os.remove(eval_model_ckpt_file)
        _exec_save_checkpoint(self.network, eval_model_ckpt_file)
def test_exec_save_checkpoint():
    """Round-trip test: save a TrainOneStepCell checkpoint, then load it back.

    Builds a minimal train network (Net + softmax CE loss + Momentum optimizer),
    writes it with _exec_save_checkpoint, and verifies load_checkpoint can read
    the file. The generated checkpoint is removed afterwards so repeated test
    runs do not accumulate artifacts on disk.
    """
    net = Net()
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
    loss_net = WithLossCell(net, loss)
    train_network = TrainOneStepCell(loss_net, opt)
    ckpt_path = "./new_ckpt.ckpt"
    try:
        _exec_save_checkpoint(train_network, ckpoint_file_name=ckpt_path)
        load_checkpoint("new_ckpt.ckpt")
    finally:
        # Clean up the file the test created, even if load_checkpoint raised.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
def step_end(self, run_context):
    """Step-end callback: every self.save_ckpt_step steps, save a checkpoint.

    Keeps at most self.max_ckpt_num checkpoints in self.output_dir by
    deleting the oldest one before writing the new file. Files are named
    "tiny_bert_<ordinal>_<save_ckpt_step>.ckpt".

    Args:
        run_context: MindSpore RunContext; original_args() supplies
            cur_step_num.
    """
    cb_params = run_context.original_args()
    if cb_params.cur_step_num % self.save_ckpt_step != 0:
        return
    # cur_step_num is an exact multiple of save_ckpt_step here, so integer
    # division is exact — no float division + int() casts needed.
    saved_ckpt_num = cb_params.cur_step_num // self.save_ckpt_step

    def _ckpt_path(index):
        # Single definition of the checkpoint naming scheme.
        return os.path.join(
            self.output_dir,
            "tiny_bert_{}_{}.ckpt".format(index, self.save_ckpt_step))

    if saved_ckpt_num > self.max_ckpt_num:
        # Rotate: drop the checkpoint that falls out of the retention window.
        oldest = _ckpt_path(saved_ckpt_num - self.max_ckpt_num)
        if os.path.exists(oldest):
            os.remove(oldest)
    _exec_save_checkpoint(self.network, _ckpt_path(saved_ckpt_num))
def _save_ckpt(self, cb_params, force_to_save=False):
    """Save checkpoint files.

    Writes to a per-process temporary file first and renames it into place
    afterwards, so a crash mid-write never leaves a corrupt checkpoint under
    the final name. Retention is enforced either by a max-count policy or a
    keep-one-per-N-minutes policy (mutually exclusive, count takes priority).
    """
    # Debounce: never save twice for the same step.
    if cb_params.cur_step_num == self._last_triggered_step:
        return
    save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
    # 1-based step index within the current epoch.
    step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
    if save_ckpt:
        # Final file name: <prefix>-<epoch>_<step_in_epoch>.ckpt
        cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
            + str(step_num_in_epoch) + ".ckpt"
        # update checkpoint file list.
        self._manager.update_ckpoint_filelist(self._directory, self._prefix)
        # keep checkpoint files number equal max number.
        if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
            self._manager.remove_oldest_ckpoint_file()
        elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
            self._cur_time_for_keep = time.time()
            # If the last save was under N minutes ago, thin out checkpoints
            # so only one per N-minute window survives.
            if (self._cur_time_for_keep - self._last_time_for_keep) \
                    < self._config.keep_checkpoint_per_n_minutes * 60:
                self._manager.keep_one_ckpoint_per_minutes(
                    self._config.keep_checkpoint_per_n_minutes,
                    self._cur_time_for_keep)
        # generate the new checkpoint file and rename it.
        # NOTE(review): module-level _save_dir looks like it is read elsewhere
        # (e.g. by a save hook) — confirm before removing the global.
        global _save_dir
        _save_dir = self._directory
        cur_file = os.path.join(self._directory, cur_ckpoint_file)
        # PID-qualified temp name keeps concurrent processes from colliding.
        tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
        gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
        self._last_time_for_keep = time.time()
        self._last_triggered_step = cb_params.cur_step_num
        if context.get_context("enable_ge"):
            # GE backend needs the checkpoint graph executed before saving.
            set_cur_net(cb_params.train_network)
            cb_params.train_network.exec_checkpoint_graph()
        _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
        # Atomically promote the temp file to its final name.
        if os.path.exists(gen_file):
            shutil.move(gen_file, cur_file)
        self._latest_ckpt_file_name = cur_file
def _save_ckpt(self, cb_params, force_to_save=False):
    """Save checkpoint files.

    Variant that writes directly to the final file name (optionally
    asynchronously via self._config.async_save) and prefixes the name with
    the parameter-server rank when running in MS_PSERVER role. Retention is
    enforced either by a max-count policy or a keep-one-per-N-minutes policy
    (mutually exclusive, count takes priority).
    """
    # Debounce: never save twice for the same step.
    if cb_params.cur_step_num == self._last_triggered_step:
        return
    save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
    # 1-based step index within the current epoch.
    step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
    if save_ckpt:
        # Final file name: <prefix>-<epoch>_<step_in_epoch>.ckpt
        cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
            + str(step_num_in_epoch) + ".ckpt"
        if os.getenv("MS_ROLE") == "MS_PSERVER":
            # Parameter-server mode: each server saves its own rank-tagged copy.
            from mindspore.parallel._ps_utils import _get_ps_mode_rank
            cur_ckpoint_file = "PServer_" + str(
                _get_ps_mode_rank()) + "_" + cur_ckpoint_file
        # update checkpoint file list.
        self._manager.update_ckpoint_filelist(self._directory, self._prefix)
        # keep checkpoint files number equal max number.
        if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
            self._manager.remove_oldest_ckpoint_file()
        elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
            self._cur_time_for_keep = time.time()
            # If the last save was under N minutes ago, thin out checkpoints
            # so only one per N-minute window survives.
            if (self._cur_time_for_keep - self._last_time_for_keep) \
                    < self._config.keep_checkpoint_per_n_minutes * 60:
                self._manager.keep_one_ckpoint_per_minutes(
                    self._config.keep_checkpoint_per_n_minutes,
                    self._cur_time_for_keep)
        # generate the new checkpoint file and rename it.
        # NOTE(review): module-level _save_dir looks like it is read elsewhere
        # (e.g. by a save hook) — confirm before removing the global.
        global _save_dir
        _save_dir = self._directory
        cur_file = os.path.join(self._directory, cur_ckpoint_file)
        self._last_time_for_keep = time.time()
        self._last_triggered_step = cb_params.cur_step_num
        if context.get_context("enable_ge"):
            # GE backend needs the checkpoint graph executed before saving.
            set_cur_net(cb_params.train_network)
            cb_params.train_network.exec_checkpoint_graph()
        _exec_save_checkpoint(cb_params.train_network, cur_file,
                              self._config.integrated_save,
                              self._config.async_save)
        self._latest_ckpt_file_name = cur_file
def train():
    """Train GAT model.

    Parses CLI args for the Cora dataset split, trains a GAT network with
    early stopping on validation accuracy/loss, checkpoints the best model to
    ckpts/gat.ckpt, then reloads it into a dropout-free copy and reports test
    loss/accuracy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr',
                        help='Data dir')
    parser.add_argument('--train_nodes_num', type=int, default=140,
                        help='Nodes numbers for training')
    parser.add_argument('--eval_nodes_num', type=int, default=500,
                        help='Nodes numbers for evaluation')
    parser.add_argument('--test_nodes_num', type=int, default=1000,
                        help='Nodes numbers for test')
    args = parser.parse_args()
    if not os.path.exists("ckpts"):
        os.mkdir("ckpts")
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)
    # train parameters
    hid_units = GatConfig.hid_units
    n_heads = GatConfig.n_heads
    early_stopping = GatConfig.early_stopping
    lr = GatConfig.lr
    l2_coeff = GatConfig.l2_coeff
    num_epochs = GatConfig.num_epochs
    feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(
        args.data_dir, args.train_nodes_num, args.eval_nodes_num,
        args.test_nodes_num)
    feature_size = feature.shape[2]
    num_nodes = feature.shape[1]
    num_class = y_train.shape[2]
    gat_net = GAT(feature,
                  biases,
                  feature_size,
                  num_class,
                  num_nodes,
                  hid_units,
                  n_heads,
                  attn_drop=GatConfig.attn_dropout,
                  ftr_drop=GatConfig.feature_dropout)
    gat_net.add_flags_recursive(fp16=True)
    eval_net = LossAccuracyWrapper(gat_net, num_class, y_val, eval_mask,
                                   l2_coeff)
    train_net = TrainGAT(gat_net, num_class, y_train, train_mask, lr, l2_coeff)
    train_net.set_train(True)
    val_acc_max = 0.0
    val_loss_min = np.inf
    # BUGFIX: initialize these before the loop. Previously curr_step,
    # val_acc_model and val_loss_model were first assigned only inside the
    # improvement branch, so the else branch / early-stop print could raise
    # UnboundLocalError if no improvement occurred on the first epoch.
    curr_step = 0
    val_acc_model = 0.0
    val_loss_model = np.inf
    for _epoch in range(num_epochs):
        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_acc = train_result[1].asnumpy()
        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_acc = eval_result[1].asnumpy()
        print(
            "Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}"
            .format(_epoch, train_loss, train_acc, eval_loss, eval_acc))
        if eval_acc >= val_acc_max or eval_loss < val_loss_min:
            # Only checkpoint when BOTH metrics improve simultaneously.
            if eval_acc >= val_acc_max and eval_loss < val_loss_min:
                val_acc_model = eval_acc
                val_loss_model = eval_loss
                _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
            val_acc_max = np.max((val_acc_max, eval_acc))
            val_loss_min = np.min((val_loss_min, eval_loss))
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == early_stopping:
                print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".
                      format(val_loss_min, val_acc_max))
                print(
                    "Early stop model validation loss: {}, accuracy{}".format(
                        val_loss_model, val_acc_model))
                break
    # Rebuild the network without dropout for deterministic testing and load
    # the best checkpoint into it.
    gat_net_test = GAT(feature,
                       biases,
                       feature_size,
                       num_class,
                       num_nodes,
                       hid_units,
                       n_heads,
                       attn_drop=0.0,
                       ftr_drop=0.0)
    load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
    gat_net_test.add_flags_recursive(fp16=True)
    test_net = LossAccuracyWrapper(gat_net_test, num_class, y_test, test_mask,
                                   l2_coeff)
    test_result = test_net()
    print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))