Code example #1
0
def create_network(name, *args, **kwargs):
    """Build the network registered under ``name``.

    Only ``"textrcnn"`` is supported. All remaining positional and keyword
    arguments are forwarded unchanged to ``textrcnn``.

    Raises:
        NotImplementedError: if ``name`` is not a supported network.
    """
    # Reject unknown names up front; the happy path then needs no nesting.
    if name != "textrcnn":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return textrcnn(*args, **kwargs)
Code example #2
0
File: eval.py  Project: yrpang/mindspore
# Evaluate a trained textrcnn checkpoint and print its accuracy.
# Fix the global random seed so evaluation runs are repeatable.
set_seed(1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='textrcnn')
    # The checkpoint is mandatory; without required=True, args.ckpt_path would
    # be None and only fail much later inside load_checkpoint.
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help='path to the trained checkpoint to evaluate')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        save_graphs=False,
                        device_target="Ascend")

    # DEVICE_ID is expected from the launch environment; default to device 0
    # for single-card runs instead of crashing on int(None).
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

    # Word-embedding matrix written by preprocessing; its row count is used
    # as the vocabulary size of the network.
    embedding_table = np.loadtxt(
        os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
    network = textrcnn(weight=Tensor(embedding_table),
                       vocab_size=embedding_table.shape[0],
                       cell=cfg.cell,
                       batch_size=cfg.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    print("============== Starting Testing ==============")
    # NOTE(review): third argument presumably selects train (True) vs. eval
    # (False) data — confirm against create_dataset.
    ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    network.set_train(False)
    # NOTE(review): amp_level='O3' runs the network in reduced precision;
    # confirm this matches the precision used during training.
    model = Model(network, loss, metrics={'acc': Accuracy()}, amp_level='O3')
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
Code example #3
0
def textrcnn_net(*args, **kwargs):
    """Construct a textrcnn network.

    A convenience alias: every positional and keyword argument is passed
    straight through to ``textrcnn`` and its result is returned unchanged.
    """
    network = textrcnn(*args, **kwargs)
    return network