# NOTE(review): This training-script fragment lost its line breaks. It is
# collapsed onto the single comment line below, so NONE of it executes.
# Restore the original newlines and indentation before relying on it. The
# fragment is also truncated: the body of the final
# `if done or step_id == max_steps-1:` is not present here.
#
# What the collapsed code does:
#   1. Builds an ActorCritic network sized by env.state_dims /
#      env.action_dims and moves it to `device`.
#   2. If ckpt_folder contains any '*.pt' file, it takes the last one in
#      sorted() order and loads it with map_location=device. From that file
#      it restores the model weights (key 'model_G_state_dict'), the resume
#      episode (key 'episode_id') and the REWARDS history (key 'REWARDS').
#   3. For each episode in range(last_episode_id, max_m_episode), it resets
#      the env and steps at most max_steps times. Each step records reward,
#      log_prob, value and mask = 1 - done. It renders when
#      episode_id % 1000 == 1.
#
# Review concerns to confirm once the line is restored:
#   - last_episode_id / REWARDS are only assigned inside the checkpoint
#     branch. Unless they are defined earlier in the file (not visible
#     here), a fresh run with no checkpoint raises NameError.
#     TODO confirm.
#   - sorted() on file names is lexicographic: 'ckpt_9.pt' sorts after
#     'ckpt_10.pt'. Picking the "last" checkpoint is only correct if the
#     names are zero-padded. Verify the naming scheme.
#   - Resuming at checkpoint['episode_id'] re-runs that episode if the id
#     was saved after the episode finished. Confirm against the save site.
#   - glob.glob(...) is evaluated twice. Caching the list once would avoid
#     a race between the existence check and the selection.
#   - `1-done` presumably assumes done is a bool or 0/1. Verify env.step.
# Create the policy network net = ActorCritic(input_dim=env.state_dims, output_dim=env.action_dims).to(device) if len(glob.glob(os.path.join(ckpt_folder, '*.pt'))) > 0: # load the last model last_ckpt = sorted(glob.glob(os.path.join(ckpt_folder, '*.pt')))[-1] print("load checkpoint", last_ckpt) checkpoint = torch.load(last_ckpt,map_location=device) net.load_state_dict(checkpoint['model_G_state_dict']) last_episode_id = checkpoint['episode_id'] REWARDS = checkpoint['REWARDS'] for episode_id in range(last_episode_id, max_m_episode): # loop: train one episode state = env.reset() rewards, log_probs, values, masks = [], [], [], [] for step_id in range(max_steps): # compute action and log-probability from the policy network, and get the action value action, log_prob, value = net.get_action(state) # execute the action and get the next state and reward state, reward, done = env.step(action) rewards.append(reward) log_probs.append(log_prob) values.append(value) masks.append(1-done) if episode_id % 1000 == 1: env.render() if done or step_id == max_steps-1: