def fix_ckpt(self, ckpt_path, new_ckpt_path):
    """Fix checkpoint file."""
    param_dict = load_checkpoint(ckpt_path)
    main_module_name = self._fixed_mapper['main_module_name']
    fixed_variable_dict = self._fixed_mapper['fixed_variable_mapper']
    fixed_module_dict = self._fixed_mapper['fixed_module_mapper']
    save_obj = list()
    for weight_name, weight_value in param_dict.items():
        weight_name_scopes = weight_name.split('.')
        weight_name_scopes.insert(0, main_module_name)
        for idx, w in enumerate(weight_name_scopes[:-1]):
            for fixed_variable_module, fixed_variable_name_mapper in fixed_variable_dict.items():
                if re.match(fixed_variable_module,
                            fixed_module_dict.get('_'.join(w.split('_')[:-1]), w)):
                    weight_name = weight_name.replace(
                        weight_name_scopes[idx + 1],
                        fixed_variable_name_mapper.get(weight_name_scopes[idx + 1],
                                                       weight_name_scopes[idx + 1]))
        obj = {'name': weight_name, 'data': Tensor(weight_value)}
        save_obj.append(obj)
    save_checkpoint(save_obj, new_ckpt_path)
    logger_console.info(f'Saved new checkpoint file to {new_ckpt_path}.')

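# A minimal sketch of the `_fixed_mapper` layout this method indexes, inferred
# from the lookups above; every concrete key and name below is hypothetical.
_fixed_mapper_sketch = {
    # Name of the generated top-level module, prepended to each weight scope.
    'main_module_name': 'MainModel',
    # Maps a module-name regex to {old_variable_name: new_variable_name}.
    'fixed_variable_mapper': {
        'BatchNorm.*': {'weight': 'gamma', 'bias': 'beta'},
    },
    # Maps a scope name (trailing '_<suffix>' stripped) to its module name.
    'fixed_module_mapper': {
        'bn': 'BatchNorm2d',
    },
}
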
def epoch_end(self, run_context):
    if self.step_eval:
        return None
    flag = ''
    cb_params = run_context.original_args()
    cur_epoch = cb_params.cur_epoch_num
    if cur_epoch > 0 and cur_epoch % self.eval_epoch == 0:
        acc = self.eval()
        if acc > self.best_acc:
            flag = '↑'
            self.best_acc = acc
            save_checkpoint(self.model._network, self.best_ckpt)
            self.logger.update_acc_ckpt(acc, self.best_ckpt)
        elif acc > self.threshold:
            self.patience_count += 1
            if self.patience_count > self.patience:
                param_dict = load_checkpoint(ckpt_file_name=self.best_ckpt,
                                             net=self.model._network)
                load_param_into_net(net=self.model._network,
                                    parameter_dict=param_dict)
                self.patience_count = 0
        print(f'* acc for epoch: {cur_epoch} is {acc * 100}%{flag}, '
              f'best acc is {self.best_acc * 100}%')

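# A minimal sketch of the surrounding callback class this method implies, with
# hypothetical constructor arguments inferred from the attributes it reads
# (step_eval, eval_epoch, threshold, patience, best_ckpt); the real class also
# defines the eval() method called above, which is not reproduced here.
from mindspore.train.callback import Callback

class EvalCallbackSketch(Callback):
    def __init__(self, model, logger, eval_epoch=1, threshold=0.0,
                 patience=5, best_ckpt='best.ckpt', step_eval=False):
        super().__init__()
        self.model = model            # mindspore.train.Model wrapping the network
        self.logger = logger
        self.eval_epoch = eval_epoch  # evaluate every `eval_epoch` epochs
        self.threshold = threshold    # accuracy floor before patience counting starts
        self.patience = patience      # reload best weights after this many stalls
        self.best_ckpt = best_ckpt
        self.step_eval = step_eval
        self.best_acc = 0.0
        self.patience_count = 0
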
def convert_weights(
    model_type: str,
    from_model: str,
    from_path: str,
    config_path: str,
    to_model: str,
    dump_path: str,
):
    model_class = MODEL_CLASSES[to_model][model_type]
    config = ConfigBase(config_path)
    model = model_class(config)
    load_weights_fct = LOAD_WEIGHTS_MAPS[to_model][model_type][from_model]
    if to_model == "tf":
        # TF models build their variables lazily; run a dummy forward pass
        # so the weights exist before loading into them.
        input_ids = tf.ones([3, 4], dtype=tf.int32)
        model(input_ids)
    load_weights_fct(model, config, from_path)
    if to_model == "pt":
        torch.save(model.state_dict(), dump_path)
    elif to_model == "tf":
        model.save_weights(dump_path)
    elif to_model == "ms":
        mindspore.save_checkpoint(model, dump_path)
    elif to_model == "of":
        flow.save(model.state_dict(), dump_path)
    elif to_model == "pd":
        paddle.save(model.state_dict(), dump_path)
    print("Save {} model to {}".format(to_model, dump_path))

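# A minimal sketch of the two registries the converter indexes; the entry
# names below are hypothetical, and the real maps presumably cover every
# (target framework, model type, source framework) combination supported.
# MODEL_CLASSES = {
#     "ms": {"bert": MSBertModel},
#     "pt": {"bert": PTBertModel},
# }
# LOAD_WEIGHTS_MAPS = {
#     "ms": {"bert": {"pt": load_ms_bert_from_pt}},
# }
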
def end(self, run_context):
    cb_params = run_context.original_args()
    print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
        self.epochs, self.epochs, np.mean(self.losses)), flush=True)
    if self.save_checkpoint:
        save_checkpoint(cb_params.train_network,
                        os.path.join(self.save_checkpoint_path, "naml_last.ckpt"))

def step_end(self, run_context):
    """step end"""
    cb_params = run_context.original_args()
    result = self.model.eval(self.ds_eval)
    if result['Accuracy'] > self.acc:
        self.acc = result['Accuracy']
        file_name = str(self.acc) + ".ckpt"
        save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
        print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)

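# Running a full evaluation pass on every training step is expensive. A
# minimal variant under the same assumptions that gates evaluation to every
# `eval_interval` steps via cb_params.cur_step_num; `self.eval_interval` is a
# hypothetical attribute not present in the original callback.
def step_end_gated(self, run_context):
    cb_params = run_context.original_args()
    if cb_params.cur_step_num % self.eval_interval != 0:  # hypothetical attribute
        return
    result = self.model.eval(self.ds_eval)
    if result['Accuracy'] > self.acc:
        self.acc = result['Accuracy']
        save_checkpoint(save_obj=cb_params.train_network,
                        ckpt_file_name=str(self.acc) + ".ckpt")
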
def csd_train(train_loader, net, opt):
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    print("[CSD] Start Training...")

    step_size = train_loader.get_dataset_size()
    lr = []
    for i in range(0, opt.epochs):
        # Halve the base learning rate every 200 epochs, expanded per step.
        cur_lr = opt.lr / (2 ** ((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    optim = nn.Adam(net.trainable_params(), learning_rate=lr, loss_scale=opt.loss_scale)

    # net_with_loss = NetWithLossCell(net)
    # Fixed: the original read contra_lambda/neg_num from an undefined global `args`.
    net_with_loss = NetWithCSDLossCell(net, opt.contra_lambda, opt.neg_num)
    train_cell = TrainOneStepCell(net_with_loss, optim)
    net.set_train()
    eval_net = net

    # time_cb = TimeMonitor(data_size=step_size)
    # loss_cb = LossMonitor()
    # metrics = {
    #     "psnr": PSNR(rgb_range=opt.rgb_range, shave=True),
    # }
    # eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size / opt.batch_size,
    #                        metrics=metrics, rank_id=rank_id)
    # cb = [time_cb, loss_cb]
    # config_ck = CheckpointConfig(save_checkpoint_steps=opt.ckpt_save_interval * step_size,
    #                              keep_checkpoint_max=opt.ckpt_save_max)
    # ckpt_cb = ModelCheckpoint(prefix=opt.filename, directory=opt.ckpt_save_path, config=config_ck)
    # if device_id == 0:
    #     cb += [ckpt_cb]

    for epoch in range(0, opt.epochs):
        epoch_loss = 0
        for iteration, batch in enumerate(train_loader.create_dict_iterator(), 1):
            lr = batch["LR"]
            hr = batch["HR"]
            loss = train_cell(lr, hr, Tensor(opt.stu_width_mult), Tensor(1.0))
            epoch_loss += loss

        print(f"Epoch[{epoch}] loss: {epoch_loss.asnumpy()}")

        # with eval_net.set_train(False):
        #     do_eval(eval_ds, eval_net)

        if epoch % 10 == 0:
            print('===> Saving model...')
            save_checkpoint(net, f'./ckpt/{opt.filename}.ckpt')

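# The schedule above halves the base LR every 200 epochs (integer division of
# the 1-based epoch index), then repeats each value `step_size` times so
# nn.Adam can consume it as a flat per-step list. A small helper mirroring the
# rule, with an illustrative base LR of 1e-4 (not a value from the source):
def _csd_lr(base_lr, epoch):
    """Per-epoch LR under the stepped schedule: halved every 200 epochs."""
    return base_lr / (2 ** ((epoch + 1) // 200))

assert _csd_lr(1e-4, 0) == 1e-4     # epochs 0..198 keep the base LR
assert _csd_lr(1e-4, 198) == 1e-4
assert _csd_lr(1e-4, 199) == 5e-5   # epochs 199..398 use half of it
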
def epoch_end(self, run_context):
    """Callback when epoch end."""
    cb_params = run_context.original_args()
    cur_epoch = cb_params.cur_epoch_num
    if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
        eval_start = time.time()
        res = self.eval_function(self.eval_param_dict)
        eval_cost = time.time() - eval_start
        print("epoch: {}, {}: {}, eval_cost:{:.2f}".format(
            cur_epoch, self.metrics_name, res, eval_cost), flush=True)
        if res >= self.best_res:
            self.best_res = res
            self.best_epoch = cur_epoch
            print("update best result: {}".format(res), flush=True)
            if self.save_best_ckpt:
                if os.path.exists(self.bast_ckpt_path):
                    self.remove_ckpoint_file(self.bast_ckpt_path)
                save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

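# `remove_ckpoint_file` is defined elsewhere in this class; a minimal sketch
# of what this helper typically does in MindSpore model-zoo eval callbacks
# (make the stale best checkpoint writable, then delete it, tolerating OS
# errors), not necessarily the project's exact implementation:
import os
import stat

def remove_ckpoint_file(self, file_name):
    """Remove the previous best checkpoint file, ignoring OS errors."""
    try:
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)
    except OSError:
        print("Failed to remove the stale checkpoint file {}.".format(file_name))
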
def save_model(self, filename=None):
    """Saving the model to a file."""
    model_name = Config().trainer.model_name
    model_dir = Config().params['model_dir']

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    if filename is not None:
        model_path = f'{model_dir}{filename}'
    else:
        model_path = f'{model_dir}{model_name}.ckpt'

    mindspore.save_checkpoint(self.model, model_path)

    if self.client_id == 0:
        logging.info("[Server #%d] Model saved to %s.", os.getpid(), model_path)
    else:
        logging.info("[Client #%d] Model saved to %s.", self.client_id, model_path)

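# A matching load path, sketched under the same assumptions (Config() layout,
# self.model); it mirrors save_model rather than reproducing the project's
# actual loader.
def load_model(self, filename=None):
    """Load parameters saved by save_model back into self.model."""
    model_name = Config().trainer.model_name
    model_dir = Config().params['model_dir']
    if filename is not None:
        model_path = f'{model_dir}{filename}'
    else:
        model_path = f'{model_dir}{model_name}.ckpt'
    param_dict = mindspore.load_checkpoint(model_path)
    mindspore.load_param_into_net(self.model, param_dict)
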
def train_process(self, epoch, train_dataset, mini_steps=None):
    """
    Training process. The data would be passed to network directly.
    """
    dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

    for i in range(epoch):
        step = 0
        for k, next_element in enumerate(dataset_helper):
            loss = self._train_forward_backward(*next_element)
            if (k + 1) % mini_steps == 0:
                step += 1
                print("epoch:", i + 1, "step:", step, "loss is ", loss)
                self._train_optim()
                self._train_clear()

        train_dataset.reset()

    save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")

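# Usage note: with mini_steps=4, gradients from four micro-batches are
# accumulated by _train_forward_backward before _train_optim applies them and
# _train_clear zeroes the accumulator, emulating a 4x larger batch. The
# mini_steps default of None would raise a TypeError at `(k + 1) % mini_steps`,
# so callers must pass an integer, e.g. (hypothetical trainer instance):
# trainer.train_process(epoch=10, train_dataset=ds_train, mini_steps=4)
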
def epoch_end(self, run_context):
    """Callback when epoch end."""
    epoch_end_f = True
    if self.sink_mode:
        # In sink mode one callback invocation covers `epoch_steps` steps,
        # which may not align with a dataset epoch; track the boundary manually.
        self.cur_step += self.epoch_steps
        epoch_end_f = False
        if self.cur_step >= self.dataset_size:
            epoch_end_f = True
            self.cur_step = self.cur_step % self.dataset_size
    cb_params = run_context.original_args()
    epoch_mseconds = (time.time() - self.epoch_time) * 1000
    per_step_mseconds = epoch_mseconds / cb_params.batch_num
    step_loss = cb_params.net_outputs
    if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
        step_loss = step_loss[0]
    if isinstance(step_loss, Tensor):
        step_loss = np.mean(step_loss.asnumpy())
    self.losses.append(step_loss)
    if epoch_end_f:
        print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
            self.cur_epoch, self.epochs, np.mean(self.losses)), flush=True)
        self.losses = []
        self.cur_epoch += 1
    if self.sink_mode:
        print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms"
              .format(self.cur_epoch, self.epochs, self.cur_step, self.dataset_size,
                      step_loss, per_step_mseconds), flush=True)
    if epoch_end_f and self.save_checkpoint:
        save_checkpoint(cb_params.train_network,
                        os.path.join(self.save_checkpoint_path, f"naml_{self.cur_epoch - 1}.ckpt"))

def save_checkpoint(state, check_list, log_dir, epoch=0):
    check_file = os.path.join(log_dir, 'model_{}.ckpt'.format(epoch))
    mindspore.save_checkpoint(state, check_file)  # MSP: potential issue
    check_list.write('model_{}.ckpt\n'.format(epoch))

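# The "potential issue" flagged above: mindspore.save_checkpoint expects an
# nn.Cell or a list of {'name': ..., 'data': ...} dicts, so if `state` arrives
# as a plain {name: Tensor} mapping (as in PyTorch-style code this appears to
# be ported from), the call can fail. A minimal conversion sketch, assuming
# `state` is such a mapping:
# save_obj = [{'name': name, 'data': value} for name, value in state.items()]
# mindspore.save_checkpoint(save_obj, check_file)
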
print("========== The Training Model is Defined. ==========") # train the model and export the encrypted CheckPoint file through Callback config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10, enc_key=b'0123456789ABCDEF', enc_mode='AES-GCM') ckpoint_cb = ModelCheckpoint(prefix='lenet_enc', directory=None, config=config_ck) model.train(10, train_dataset, dataset_sink_mode=False, callbacks=[ckpoint_cb, LossMonitor(1875)]) acc = model.eval(eval_dataset, dataset_sink_mode=False) print("Accuracy: {}".format(acc["Accuracy"])) # export the encrypted CheckPoint file through save_checkpoint save_checkpoint(network, 'lenet_enc.ckpt', enc_key=b'0123456789ABCDEF', enc_mode='AES-GCM') # load encrypted CheckPoint file and eval param_dict = load_checkpoint('lenet_enc-10_1875.ckpt', dec_key=b'0123456789ABCDEF', dec_mode='AES-GCM') load_param_into_net(network, param_dict) acc = model.eval(eval_dataset, dataset_sink_mode=False) print("Accuracy loading encrypted CheckPoint: {}".format(acc["Accuracy"]))