Example #1
0
File: eval.py, Project: yrpang/mindspore
    # Record the environment's discrete action count in cfg before the agent
    # is constructed from it.
    cfg.action_space_dim = env.action_space.n
    agent = Agent(**cfg)

    # Optionally restore the policy network from a checkpoint. A non-empty
    # result from load_param_into_net (presumably the names of parameters it
    # could not load) is treated as a hard failure.
    if args.ckpt_path:
        param_dict = load_checkpoint(args.ckpt_path)
        not_load_param = load_param_into_net(agent.policy_net, param_dict)
        if not_load_param:
            raise ValueError("Load param into net fail!")

    # Single source of truth for the episode count, used both to drive the
    # loop and to compute the mean, so the two cannot drift apart.
    num_eval_episodes = 50
    score = 0
    # NOTE(review): load_dict() runs after the checkpoint restore above; its
    # effect is defined in Agent (not visible here) -- confirm the ordering.
    agent.load_dict()
    for episode in range(num_eval_episodes):
        s0 = env.reset()
        # Starts at 1, not 0: the terminal step's reward is never added below
        # (the loop breaks first), so this presumably credits that final step
        # assuming a per-step reward of 1 -- TODO confirm for the environment.
        total_reward = 1
        while True:
            a0 = agent.eval_act(s0)
            s1, r1, done, _ = env.step(a0)
            # Terminal step: stop without accumulating its reward.
            if done:
                break
            total_reward += r1
            s0 = s1
        score += total_reward
        print("episode", episode, "total_reward", total_reward)
    print("mean_reward", score / num_eval_episodes)