def test(config):
    """Evaluate a CNN-CTC checkpoint on the test dataset and print the accuracy.

    The network is built from ``config``, weights are loaded from
    ``config.CKPT_PATH``, and every batch produced by
    ``test_dataset_creator()`` is run through the network. Timings for the
    forward pass, the device-to-host copy and the decoding step are averaged
    and printed for every batch after the first one.

    Args:
        config: configuration object; this function reads ``NUM_CLASS``,
            ``HIDDEN_SIZE``, ``FINAL_FEATURE_WIDTH``, ``CKPT_PATH``,
            ``CHARACTER`` and ``TEST_BATCH_SIZE`` from it.

    Returns:
        None. All results are reported through ``print``.
    """
    ds = test_dataset_creator()

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CKPT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    model_run_time = AverageMeter()
    npu_to_cpu_time = AverageMeter()
    postprocess_time = AverageMeter()

    count = 0
    correct_count = 0
    for data in ds.create_tuple_iterator():
        img, _, text, _, length = data

        img_tensor = Tensor(img, mstype.float32)

        # Forward pass.
        model_run_begin = time.time()
        model_predict = net(img_tensor)
        model_run_end = time.time()
        model_run_time.update(model_run_end - model_run_begin)

        # Copy the prediction back to host memory as a NumPy array.
        npu_to_cpu_begin = time.time()
        model_predict = np.squeeze(model_predict.asnumpy())
        npu_to_cpu_end = time.time()
        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)

        # Greedy decoding: arg-max over axis 2, then flatten for the converter.
        # NOTE(review): assumes the squeezed prediction is 3-D with the
        # sequence length on axis 1 and exactly TEST_BATCH_SIZE samples per
        # batch -- confirm against the dataset/model definition.
        postprocess_begin = time.time()
        preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
        preds_index = np.argmax(model_predict, 2)
        preds_index = np.reshape(preds_index, [-1])
        preds_str = converter.decode(preds_index, preds_size)
        postprocess_end = time.time()
        postprocess_time.update(postprocess_end - postprocess_begin)

        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

        if count == 0:
            # Nothing has been counted yet, i.e. this is the first batch:
            # reset the meters so its timings are not part of the averages.
            model_run_time.reset()
            npu_to_cpu_time.reset()
            postprocess_time.reset()
        else:
            print('---------model run time--------', model_run_time.avg)
            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
            print('---------postprocess run time--------', postprocess_time.avg)

        print("Prediction samples: \n", preds_str[:5])
        print("Ground truth: \n", label_str[:5])

        # `count` is the number of evaluated samples, not batches.
        for pred, label in zip(preds_str, label_str):
            if pred == label:
                correct_count += 1
            count += 1

    if count == 0:
        # An empty dataset used to crash with ZeroDivisionError below.
        print('accuracy is undefined: the test dataset produced no samples')
        return

    print('accuracy: ', correct_count / count)
def train(args_opt, config):
    """Train the CNN-CTC network, optionally in data-parallel mode.

    Args:
        args_opt: parsed command-line options; ``run_distribute`` is always
            read, ``device_id`` only when ``run_distribute`` is truthy.
        config: configuration object supplying the model sizes, the optional
            ``CKPT_PATH`` to resume from, optimizer hyper-parameters,
            checkpoint settings and ``TRAIN_EPOCHS``.

    Returns:
        None.
    """
    if args_opt.run_distribute:
        init()
        context.set_auto_parallel_context(parallel_mode="data_parallel")

    dataset = dataset_creator(args_opt.run_distribute)

    backbone = CNNCTC_Model(
        config.NUM_CLASS,
        config.HIDDEN_SIZE,
        config.FINAL_FEATURE_WIDTH,
    )
    backbone.set_train(True)

    # An empty CKPT_PATH means there is nothing to resume from.
    if config.CKPT_PATH == '':
        print('train from scratch...')
    else:
        load_param_into_net(backbone, load_checkpoint(config.CKPT_PATH))
        print('parameters loaded!')

    loss_fn = ctc_loss()
    optimizer = mindspore.nn.RMSProp(
        params=backbone.trainable_params(),
        centered=True,
        learning_rate=config.LR_PARA,
        momentum=config.MOMENTUM,
        loss_scale=config.LOSS_SCALE,
    )

    loss_net = WithLossCell(backbone, loss_fn)

    scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        config.LOSS_SCALE, False)
    model = Model(
        loss_net,
        optimizer=optimizer,
        loss_scale_manager=scale_manager,
        amp_level="O2",
    )

    loss_cb = LossCallBack()
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
        keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM,
    )
    ckpt_cb = ModelCheckpoint(
        prefix="CNNCTC",
        config=ckpt_config,
        directory=config.SAVE_PATH,
    )

    # Checkpoints are written in single-device runs and, when distributed,
    # only by device 0; every other device just reports the loss.
    callbacks = [loss_cb]
    if not args_opt.run_distribute or args_opt.device_id == 0:
        callbacks.append(ckpt_cb)

    model.train(
        config.TRAIN_EPOCHS,
        dataset,
        callbacks=callbacks,
        dataset_sink_mode=False,
    )
# Export script: writes a CNN-CTC checkpoint out as an AIR model file.
# The context is configured and the command line is parsed at import time.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

parser = argparse.ArgumentParser(description='CNNCTC_export')
parser.add_argument(
    '--ckpt_file',
    type=str,
    default='./ckpts/cnn_ctc.ckpt',
    help='CNN&CTC ckpt file.',
)
parser.add_argument(
    '--output_file',
    type=str,
    default='cnn_ctc',
    help='CNN&CTC output air name.',
)
args_opt = parser.parse_args()

if __name__ == '__main__':
    cfg = Config_CNNCTC()

    # The command-line checkpoint wins; an empty --ckpt_file falls back to
    # the path stored in the configuration.
    config_ckpt = cfg.CKPT_PATH
    ckpt_path = config_ckpt if args_opt.ckpt_file == "" else args_opt.ckpt_file

    net = CNNCTC_Model(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
    load_checkpoint(ckpt_path, net=net)

    # All-zero float32 input of shape [batch, 3, IMG_H, IMG_W] handed to
    # export() as the sample input.
    bs = cfg.TEST_BATCH_SIZE
    input_shape = [bs, 3, cfg.IMG_H, cfg.IMG_W]
    input_data = Tensor(np.zeros(input_shape), mstype.float32)

    export(net, input_data, file_name=args_opt.output_file, file_format="AIR")
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Args:
        name: network identifier; only ``"cnnctc"`` is supported.
        *args: positional arguments forwarded to ``CNNCTC_Model``.
        **kwargs: keyword arguments forwarded to ``CNNCTC_Model``.

    Returns:
        A ``CNNCTC_Model`` instance.

    Raises:
        NotImplementedError: if ``name`` is anything other than ``"cnnctc"``.
    """
    # Reject unknown names first so the supported path reads straight through.
    if name != "cnnctc":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return CNNCTC_Model(*args, **kwargs)
def cnnctc(*args, **kwargs):
    """Construct a ``CNNCTC_Model``, forwarding every argument unchanged.

    Args:
        *args: positional arguments passed to ``CNNCTC_Model``.
        **kwargs: keyword arguments passed to ``CNNCTC_Model``.

    Returns:
        The newly created ``CNNCTC_Model`` instance.
    """
    model = CNNCTC_Model(*args, **kwargs)
    return model