Example #1
def fix_ckpt(self, ckpt_path, new_ckpt_path):
    """Fix parameter names in a checkpoint file and save the result."""
    param_dict = load_checkpoint(ckpt_path)

    main_module_name = self._fixed_mapper['main_module_name']
    fixed_variable_dict = self._fixed_mapper['fixed_variable_mapper']
    fixed_module_dict = self._fixed_mapper['fixed_module_mapper']

    save_obj = list()
    for weight_name, weight_value in param_dict.items():
        weight_name_scopes = weight_name.split('.')
        weight_name_scopes.insert(0, main_module_name)
        for idx, w in enumerate(weight_name_scopes[:-1]):
            # Strip the trailing index suffix (e.g. 'conv_3' -> 'conv') and
            # map the scope to its canonical module name before matching.
            module_name = fixed_module_dict.get('_'.join(w.split('_')[:-1]), w)
            for fixed_variable_module, fixed_variable_name_mapper in fixed_variable_dict.items():
                if re.match(fixed_variable_module, module_name):
                    # Rename the child scope according to the variable mapper.
                    weight_name = weight_name.replace(
                        weight_name_scopes[idx + 1],
                        fixed_variable_name_mapper.get(weight_name_scopes[idx + 1],
                                                       weight_name_scopes[idx + 1]))

        save_obj.append({'name': weight_name, 'data': Tensor(weight_value)})

    save_checkpoint(save_obj, new_ckpt_path)
    logger_console.info(f'Saved new checkpoint file to {new_ckpt_path}.')
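For reference, `mindspore.save_checkpoint` accepts either a network `Cell` or the list-of-dicts form built above; a minimal sketch of the latter (parameter name, shape, and file name are illustrative):

import numpy as np
import mindspore
from mindspore import Tensor

# save_checkpoint also takes a list of {'name': str, 'data': Tensor} dicts,
# which is what fix_ckpt assembles for the renamed parameters.
params = [{'name': 'conv1.weight', 'data': Tensor(np.ones((8, 3, 3, 3), np.float32))}]
mindspore.save_checkpoint(params, 'renamed.ckpt')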
Example #2
    def epoch_end(self, run_context):
        """Evaluate at the end of an epoch and track the best checkpoint."""
        if self.step_eval:
            return
        flag = ''
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch > 0 and cur_epoch % self.eval_epoch == 0:
            acc = self.eval()

            if acc > self.best_acc:
                flag = '↑'
                self.best_acc = acc
                save_checkpoint(self.model._network, self.best_ckpt)
                self.logger.update_acc_ckpt(acc, self.best_ckpt)
            elif acc > self.threshold:
                # Accuracy stalled above the threshold: count towards patience,
                # and roll the network back to the best checkpoint once
                # patience is exhausted.
                self.patience_count += 1
                if self.patience_count > self.patience:
                    param_dict = load_checkpoint(ckpt_file_name=self.best_ckpt,
                                                 net=self.model._network)
                    load_param_into_net(net=self.model._network,
                                        parameter_dict=param_dict)
                    self.patience_count = 0

            print(f'* acc for epoch: {cur_epoch} is {acc * 100}%{flag}, '
                  f'best acc is {self.best_acc * 100}%')
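Callbacks like this one plug into MindSpore training via `model.train(callbacks=[...])`. A minimal, self-contained sketch of the same pattern (the `SaveOnEpochEnd` class and its arguments are illustrative, not part of the source project):

import os
from mindspore import save_checkpoint
from mindspore.train.callback import Callback

class SaveOnEpochEnd(Callback):
    """Sketch: save the training network after every epoch."""

    def __init__(self, ckpt_dir):
        super().__init__()
        self.ckpt_dir = ckpt_dir

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        path = os.path.join(self.ckpt_dir, f'epoch_{cb_params.cur_epoch_num}.ckpt')
        save_checkpoint(cb_params.train_network, path)

# Hypothetical usage with an existing Model instance:
# model.train(10, train_dataset, callbacks=[SaveOnEpochEnd('./ckpt')])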
Example #3
def convert_weights(
    model_type: str,
    from_model: str,
    from_path: str,
    config_path: str,
    to_model: str,
    dump_path: str,
):
    model_class = MODEL_CLASSES[to_model][model_type]
    config = ConfigBase(config_path)
    model = model_class(config)
    load_weights_fct = LOAD_WEIGHTS_MAPS[to_model][model_type][from_model]
    if to_model == "tf":
        # TF models must be built (by a forward pass on dummy input)
        # before weights can be assigned.
        input_ids = tf.ones([3, 4], dtype=tf.int32)
        model(input_ids)
    load_weights_fct(model, config, from_path)

    # Dump the converted weights in the target framework's native format.
    if to_model == "pt":
        torch.save(model.state_dict(), dump_path)
    elif to_model == "tf":
        model.save_weights(dump_path)
    elif to_model == "ms":
        mindspore.save_checkpoint(model, dump_path)
    elif to_model == "of":
        flow.save(model.state_dict(), dump_path)
    elif to_model == "pd":
        paddle.save(model.state_dict(), dump_path)
    print("Saved {} model to {}".format(to_model, dump_path))
Example #4
    def end(self, run_context):
        """Print the final average loss and save the last checkpoint."""
        cb_params = run_context.original_args()
        print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
            self.epochs, self.epochs, np.mean(self.losses)), flush=True)
        if self.save_checkpoint:
            save_checkpoint(cb_params.train_network,
                            os.path.join(self.save_checkpoint_path, "naml_last.ckpt"))
Example #5
    def step_end(self, run_context):
        """Evaluate after every step and keep the best-accuracy checkpoint."""
        cb_params = run_context.original_args()
        result = self.model.eval(self.ds_eval)
        if result['Accuracy'] > self.acc:
            self.acc = result['Accuracy']
            # The accuracy value doubles as the checkpoint file name.
            file_name = str(self.acc) + ".ckpt"
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Saved the maximum-accuracy checkpoint, accuracy is", self.acc)
Example #6
def csd_train(train_loader, net, opt):
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    print("[CSD] Start Training...")

    step_size = train_loader.get_dataset_size()
    # Step-wise schedule: the learning rate halves every 200 epochs.
    lr = []
    for i in range(0, opt.epochs):
        cur_lr = opt.lr / (2 ** ((i + 1) // 200))
        lr.extend([cur_lr] * step_size)
    optim = nn.Adam(net.trainable_params(), learning_rate=lr, loss_scale=opt.loss_scale)

    # net_with_loss = NetWithLossCell(net)
    net_with_loss = NetWithCSDLossCell(net, opt.contra_lambda, opt.neg_num)
    train_cell = TrainOneStepCell(net_with_loss, optim)
    net.set_train()
    eval_net = net

    # time_cb = TimeMonitor(data_size=step_size)
    # loss_cb = LossMonitor()
    # metrics = {
    #     "psnr": PSNR(rgb_range=opt.rgb_range, shave=True),
    # }
    # eval_cb = EvalCallBack(eval_net, eval_ds, args.test_every, step_size / opt.batch_size, metrics=metrics,
    #                        rank_id=rank_id)
    # cb = [time_cb, loss_cb]
    # config_ck = CheckpointConfig(save_checkpoint_steps=opt.ckpt_save_interval * step_size,
    #                              keep_checkpoint_max=opt.ckpt_save_max)
    # ckpt_cb = ModelCheckpoint(prefix=opt.filename, directory=opt.ckpt_save_path, config=config_ck)
    # if device_id == 0:
    #     cb += [ckpt_cb]

    for epoch in range(0, opt.epochs):
        epoch_loss = 0
        for iteration, batch in enumerate(train_loader.create_dict_iterator(), 1):
            # lr_img/hr_img: avoid shadowing the learning-rate schedule above.
            lr_img = batch["LR"]
            hr_img = batch["HR"]

            loss = train_cell(lr_img, hr_img, Tensor(opt.stu_width_mult), Tensor(1.0))
            epoch_loss += loss

        print(f"Epoch[{epoch}] loss: {epoch_loss.asnumpy()}")
        # with eval_net.set_train(False):
        #     do_eval(eval_ds, eval_net)

        if epoch % 10 == 0:
            print('===> Saving model...')
            save_checkpoint(net, f'./ckpt/{opt.filename}.ckpt')
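The `NetWithCSDLossCell`/`TrainOneStepCell` pairing above follows MindSpore's usual wrapper pattern: one Cell computes the loss, another wraps it with an optimizer. A minimal sketch of a generic with-loss cell (the class name and L1 loss are illustrative, not the CSD loss):

import mindspore.nn as nn

class NetWithSimpleLossCell(nn.Cell):
    """Sketch: wrap a backbone and a loss into one Cell for TrainOneStepCell."""

    def __init__(self, backbone, loss_fn=None):
        super().__init__(auto_prefix=False)
        self.backbone = backbone
        self.loss_fn = loss_fn or nn.L1Loss()

    def construct(self, data, label):
        out = self.backbone(data)
        return self.loss_fn(out, label)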
Example #7
    def epoch_end(self, run_context):
        """Callback when epoch ends."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            eval_start = time.time()
            res = self.eval_function(self.eval_param_dict)
            eval_cost = time.time() - eval_start
            print("epoch: {}, {}: {}, eval_cost:{:.2f}".format(cur_epoch, self.metrics_name, res, eval_cost),
                  flush=True)
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    # Replace the previous best checkpoint with the new one.
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)
Example #8
    def save_model(self, filename=None):
        """Saving the model to a file."""
        model_name = Config().trainer.model_name
        model_dir = Config().params['model_dir']

        os.makedirs(model_dir, exist_ok=True)

        if filename is not None:
            model_path = os.path.join(model_dir, filename)
        else:
            model_path = os.path.join(model_dir, f'{model_name}.ckpt')

        mindspore.save_checkpoint(self.model, model_path)

        # client_id == 0 denotes the server process.
        if self.client_id == 0:
            logging.info("[Server #%d] Model saved to %s.", os.getpid(),
                         model_path)
        else:
            logging.info("[Client #%d] Model saved to %s.", self.client_id,
                         model_path)
Example #9
    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. The data is passed to the network directly.
        mini_steps is the gradient-accumulation window and must be a
        positive integer.
        """
        dataset_helper = DatasetHelper(train_dataset,
                                       dataset_sink_mode=False,
                                       epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                # Accumulate gradients for mini_steps batches, then apply the
                # optimizer and clear the accumulated gradients.
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is", loss)
                    self._train_optim()
                    self._train_clear()

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward,
                        "gradient_accumulation.ckpt")
Example #10
    def epoch_end(self, run_context):
        """Callback when epoch ends."""
        # Fetch cb_params up front: it is needed for the final checkpoint
        # save even when sink mode is disabled.
        cb_params = run_context.original_args()
        epoch_end_f = True
        if self.sink_mode:
            # In sink mode this callback fires once per sunk chunk, so track
            # steps manually and only treat it as a real epoch boundary once
            # a full pass over the dataset has been consumed.
            self.cur_step += self.epoch_steps
            epoch_end_f = False
            if self.cur_step >= self.dataset_size:
                epoch_end_f = True
                self.cur_step = self.cur_step % self.dataset_size
            epoch_mseconds = (time.time() - self.epoch_time) * 1000
            per_step_mseconds = epoch_mseconds / cb_params.batch_num
            step_loss = cb_params.net_outputs
            if isinstance(step_loss,
                          (tuple, list)) and isinstance(step_loss[0], Tensor):
                step_loss = step_loss[0]
            if isinstance(step_loss, Tensor):
                step_loss = np.mean(step_loss.asnumpy())
            self.losses.append(step_loss)
        if epoch_end_f:
            print("epoch: {:3d}/{:3d}, avg loss:{:5.3f}".format(
                self.cur_epoch, self.epochs, np.mean(self.losses)), flush=True)
            self.losses = []
            self.cur_epoch += 1
        if self.sink_mode:
            print("epoch: {:3d}/{:3d}, step:{:5d}/{:5d}, loss:{:5.3f}, per step time:{:5.3f} ms"
                  .format(self.cur_epoch, self.epochs, self.cur_step,
                          self.dataset_size, step_loss, per_step_mseconds),
                  flush=True)
        if epoch_end_f and self.save_checkpoint:
            save_checkpoint(cb_params.train_network,
                            os.path.join(self.save_checkpoint_path,
                                         f"naml_{self.cur_epoch - 1}.ckpt"))
Example #11
def save_checkpoint(state, check_list, log_dir, epoch=0):
    # Note: this helper shadows mindspore.save_checkpoint.
    check_file = os.path.join(log_dir, 'model_{}.ckpt'.format(epoch))
    # MSP: potential issue -- mindspore.save_checkpoint expects a Cell or a
    # list of {'name', 'data'} dicts, so an arbitrary state object may fail.
    mindspore.save_checkpoint(state, check_file)
    check_list.write('model_{}.ckpt\n'.format(epoch))
Example #12
    print("========== The Training Model is Defined. ==========")

    # train the model and export the encrypted CheckPoint file through Callback
    config_ck = CheckpointConfig(save_checkpoint_steps=1875,
                                 keep_checkpoint_max=10,
                                 enc_key=b'0123456789ABCDEF',
                                 enc_mode='AES-GCM')
    ckpoint_cb = ModelCheckpoint(prefix='lenet_enc',
                                 directory=None,
                                 config=config_ck)
    model.train(10,
                train_dataset,
                dataset_sink_mode=False,
                callbacks=[ckpoint_cb, LossMonitor(1875)])
    acc = model.eval(eval_dataset, dataset_sink_mode=False)
    print("Accuracy: {}".format(acc["Accuracy"]))

    # export the encrypted CheckPoint file through save_checkpoint
    save_checkpoint(network,
                    'lenet_enc.ckpt',
                    enc_key=b'0123456789ABCDEF',
                    enc_mode='AES-GCM')

    # load the encrypted CheckPoint file written by ModelCheckpoint
    # (named {prefix}-{epoch}_{step}.ckpt) and evaluate
    param_dict = load_checkpoint('lenet_enc-10_1875.ckpt',
                                 dec_key=b'0123456789ABCDEF',
                                 dec_mode='AES-GCM')
    load_param_into_net(network, param_dict)
    acc = model.eval(eval_dataset, dataset_sink_mode=False)
    print("Accuracy loading encrypted CheckPoint: {}".format(acc["Accuracy"]))