示例#1
0
def test(config):
    """Evaluate a CNN-CTC checkpoint on the test dataset and print accuracy.

    Builds a ``CNNCTC_Model``, loads the weights stored at
    ``config.CKPT_PATH``, and runs the model over every batch produced by
    ``test_dataset_creator()``. For each time step the argmax class is taken,
    and the per-sample index sequences are decoded with ``CTCLabelConverter``.
    Per-stage timing (model run, device-to-host copy, post-processing) and
    a few sample predictions are printed. The function finishes by printing
    exact-match sequence accuracy.

    Args:
        config: configuration object. Reads ``NUM_CLASS``, ``HIDDEN_SIZE``,
            ``FINAL_FEATURE_WIDTH``, ``CKPT_PATH`` and ``CHARACTER``.
    """
    ds = test_dataset_creator()

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CKPT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    model_run_time = AverageMeter()
    npu_to_cpu_time = AverageMeter()
    postprocess_time = AverageMeter()

    count = 0  # number of evaluated samples (not batches)
    correct_count = 0
    for data in ds.create_tuple_iterator():
        img, _, text, _, length = data

        img_tensor = Tensor(img, mstype.float32)

        model_run_begin = time.time()
        model_predict = net(img_tensor)
        model_run_end = time.time()
        model_run_time.update(model_run_end - model_run_begin)

        # asnumpy() copies the result to host memory; it is timed separately.
        npu_to_cpu_begin = time.time()
        # NOTE(review): np.squeeze drops *every* size-1 axis, including the
        # batch axis when a batch holds a single sample; the argmax over
        # axis 2 below would then fail. Confirm batch size is always > 1.
        model_predict = np.squeeze(model_predict.asnumpy())
        npu_to_cpu_end = time.time()
        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)

        postprocess_begin = time.time()
        # Assumes model_predict is (batch, seq_len, num_class) after squeeze
        # (argmax over axis 2 requires rank 3) -- TODO confirm.
        # Take the batch size from the prediction itself rather than
        # config.TEST_BATCH_SIZE, so a smaller final batch gets exactly one
        # length entry per real sample.
        batch_size, seq_len = model_predict.shape[0], model_predict.shape[1]
        preds_size = np.array([seq_len] * batch_size)
        preds_index = np.argmax(model_predict, 2)
        preds_index = np.reshape(preds_index, [-1])
        preds_str = converter.decode(preds_index, preds_size)
        postprocess_end = time.time()
        postprocess_time.update(postprocess_end - postprocess_begin)

        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

        if count == 0:
            # Still 0 only during the first batch. Its timings are discarded,
            # presumably to exclude warm-up/compilation cost -- confirm.
            model_run_time.reset()
            npu_to_cpu_time.reset()
            postprocess_time.reset()
        else:
            print('---------model run time--------', model_run_time.avg)
            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
            print('---------postprocess run time--------', postprocess_time.avg)

        print("Prediction samples: \n", preds_str[:5])
        print("Ground truth: \n", label_str[:5])
        for pred, label in zip(preds_str, label_str):
            if pred == label:
                correct_count += 1
            count += 1

    # Guard against an empty dataset, which previously raised ZeroDivisionError.
    if count == 0:
        print('accuracy: N/A (no samples evaluated)')
        return
    print('accuracy: ', correct_count / count)
示例#2
0
文件: train.py 项目: yrpang/mindspore
def train(args_opt, config):
    """Train a CNNCTC_Model with RMSProp and a CTC loss.

    Optionally sets up data-parallel distributed training and optionally
    warm-starts from ``config.CKPT_PATH``. Training runs through a MindSpore
    ``Model`` with a fixed loss scale and ``amp_level="O2"``. Checkpoints are
    written under ``config.SAVE_PATH``. In a distributed run only the process
    whose ``args_opt.device_id`` is 0 attaches the checkpoint callback.

    Args:
        args_opt: parsed CLI options. Reads ``run_distribute``, and also
            ``device_id`` when distributed.
        config: training configuration (model sizes, LR, momentum, loss
            scale, epochs and checkpoint settings).
    """
    distributed = args_opt.run_distribute
    if distributed:
        init()
        context.set_auto_parallel_context(parallel_mode="data_parallel")

    dataset = dataset_creator(distributed)

    backbone = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE,
                            config.FINAL_FEATURE_WIDTH)
    backbone.set_train(True)

    resume_path = config.CKPT_PATH
    if resume_path == '':
        print('train from scratch...')
    else:
        load_param_into_net(backbone, load_checkpoint(resume_path))
        print('parameters loaded!')

    loss_fn = ctc_loss()
    optimizer = mindspore.nn.RMSProp(params=backbone.trainable_params(),
                                     centered=True,
                                     learning_rate=config.LR_PARA,
                                     momentum=config.MOMENTUM,
                                     loss_scale=config.LOSS_SCALE)

    # Model trains the loss-wrapped network. The optimizer above was built
    # from the bare backbone's trainable parameters.
    net_with_loss = WithLossCell(backbone, loss_fn)
    # Second positional argument is drop_overflow_update (see MindSpore docs).
    scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(
        config.LOSS_SCALE, False)
    model = Model(net_with_loss,
                  optimizer=optimizer,
                  loss_scale_manager=scale_manager,
                  amp_level="O2")

    loss_monitor = LossCallBack()
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
        keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
    ckpt_saver = ModelCheckpoint(prefix="CNNCTC",
                                 config=ckpt_config,
                                 directory=config.SAVE_PATH)

    # Single-process runs always save checkpoints. Distributed runs save
    # only from device 0, so that one process writes them.
    callbacks = [loss_monitor]
    if not distributed or args_opt.device_id == 0:
        callbacks.append(ckpt_saver)

    model.train(config.TRAIN_EPOCHS,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=False)
示例#3
0
# Runs at import time: select graph mode on an Ascend device.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# CLI: checkpoint to load and base name of the exported AIR file.
# Note: arguments are parsed at import time, not only under __main__.
parser = argparse.ArgumentParser(description='CNNCTC_export')
parser.add_argument('--ckpt_file', type=str,
                    default='./ckpts/cnn_ctc.ckpt',
                    help='CNN&CTC ckpt file.')
parser.add_argument('--output_file', type=str,
                    default='cnn_ctc',
                    help='CNN&CTC output air name.')
args_opt = parser.parse_args()

if __name__ == '__main__':
    # Export a trained CNNCTC_Model to AIR format.
    cfg = Config_CNNCTC()
    default_ckpt = cfg.CKPT_PATH
    # A non-empty --ckpt_file overrides the path stored in the config.
    ckpt_path = args_opt.ckpt_file if args_opt.ckpt_file != "" else default_ckpt

    net = CNNCTC_Model(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
    load_checkpoint(ckpt_path, net=net)

    # All-zero input of shape (TEST_BATCH_SIZE, 3, IMG_H, IMG_W). Presumably
    # only its shape and dtype matter for the export -- confirm.
    dummy_shape = [cfg.TEST_BATCH_SIZE, 3, cfg.IMG_H, cfg.IMG_W]
    input_data = Tensor(np.zeros(dummy_shape), mstype.float32)

    export(net, input_data, file_name=args_opt.output_file, file_format="AIR")
示例#4
0
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``"cnnctc"`` is supported. The remaining positional and keyword
    arguments are forwarded unchanged to ``CNNCTC_Model``.

    Raises:
        NotImplementedError: if ``name`` is not ``"cnnctc"``.
    """
    if name != "cnnctc":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return CNNCTC_Model(*args, **kwargs)
示例#5
0
def cnnctc(*args, **kwargs):
    """Construct a ``CNNCTC_Model``, forwarding all arguments unchanged."""
    model_cls = CNNCTC_Model
    return model_cls(*args, **kwargs)