import logging
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

# Project-local modules; the import paths are assumed from the identifiers
# used below and may differ in the actual repo layout.
from config import load_config
from env import GymEnv
from model import ActorCritic


def single_agent():
    config = load_config()
    # num_agents = config['num_agents']
    torch.set_num_threads(1)

    env = GymEnv(config=config)
    env.reset()

    net = ActorCritic(True, config)
    net.ActorNetwork.init_params()
    net.CriticNetwork.init_params()

    bwe = config['sending_rate'][config['default_bwe']]
    i = 1
    s_batch = []
    r_batch = []
    a_batch = []

    # Experience RTC if not forced to stop.
    ax = []
    ay = []
    plt.ion()
    while True:
        # TODO: agent interacts with gym
        state, reward, done, _ = env.step(bwe)
        r_batch.append(reward)

        action = net.predict(state)
        bwe = config['sending_rate'][action]
        a_batch.append(action)
        s_batch.append(state)

        # TODO: needs to be fixed
        if done:
            action = config['default_bwe']
            bwe = config['sending_rate'][action]

            # Update the network on the finished episode.
            net.getNetworkGradient(s_batch, a_batch, r_batch, done)
            net.updateNetwork()
            print('Network updated.')

            i += 1
            ax.append(i)
            # ay.append(entropy)
            ay.append(reward)
            plt.clf()
            plt.plot(ax, ay)
            plt.pause(0.1)

            # Clear the episode buffers so the next update only sees fresh
            # experience; the original code let them grow across episodes.
            del s_batch[:]
            del a_batch[:]
            del r_batch[:]
            # s_batch.append(np.zeros(config['state_dim'], config['state_length']))
            # a_batch.append(action)
            env.reset()
            print('Environment has been reset.')
            print('Epoch {}, Reward: {}'.format(i - 1, reward))

            if i % 100 == 0:
                # print('Current BWE: ' + str(bwe))
                torch.save(net.ActorNetwork.state_dict(),
                           config['model_dir'] + '/actor1_{}.pt'.format(i))
                torch.save(net.CriticNetwork.state_dict(),
                           config['model_dir'] + '/critic13m_{}.pt'.format(i))
                print('Model saved.')
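# --- Hedged sketch: the config keys these functions read. ---
# This dict is an assumption reconstructed from the accesses above and in
# central_agent() below, not the project's actual config file; every key
# name comes from the code, but all values here are purely illustrative.
_EXAMPLE_CONFIG = {
    'num_agents': 4,                 # worker processes feeding the central agent
    'sending_rate': [3e5, 5e5, 1e6, 1.5e6, 2e6, 3e6],  # candidate BWE actions (bps, assumed)
    'default_bwe': 2,                # index into 'sending_rate' used on reset
    'state_dim': 6,                  # features per observation (assumed)
    'state_length': 8,               # history length of the state (assumed)
    'model_dir': './models',         # checkpoint directory
    'log_dir': './logs',             # training-log directory
    'save_interval': 100,            # epochs between checkpoints
    'load_model': False,             # resume from a saved actor/critic pair
}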
def central_agent(net_params_queue, exp_queues, config):
    torch.set_num_threads(1)

    # Log training info.
    logging.basicConfig(filename=config['log_dir'] + '/Central_agent_training.log',
                        filemode='w',
                        level=logging.INFO)

    assert len(net_params_queue) == config['num_agents']
    assert len(exp_queues) == config['num_agents']

    net = ActorCritic(True, config)
    # Since the original Pensieve does not use the critic in workers, only the
    # actor parameters are pushed into net_params_queue; both networks are
    # still checkpointed separately below.
    if config['load_model']:
        actor_net_params = torch.load(config['model_dir'] + '/actor_300k1_80.pt')
        critic_net_params = torch.load(config['model_dir'] + '/critic_300k1_80.pt')
        net.ActorNetwork.load_state_dict(actor_net_params)
        net.CriticNetwork.load_state_dict(critic_net_params)
    else:
        net.ActorNetwork.init_params()
        net.CriticNetwork.init_params()

    # Broadcast the initial actor parameters to every worker. (In the original,
    # actor_net_params was left undefined on the fresh-start path, which raised
    # a NameError here.)
    actor_net_params = list(net.ActorNetwork.parameters())
    for i in range(config['num_agents']):
        net_params_queue[i].put(actor_net_params)

    epoch = 0
    total_reward = 0.0
    total_batch_len = 0.0
    episode_entropy = 0.0
    ax = []
    ay = []
    plt.ion()
    while True:
        start = time.time()

        # Push the latest actor parameters to all workers.
        actor_net_params = list(net.ActorNetwork.parameters())
        for i in range(config['num_agents']):
            net_params_queue[i].put(actor_net_params)

        # Collect one batch of experience from each worker and accumulate
        # gradients before a single synchronous update.
        for i in range(config['num_agents']):
            s_batch, a_batch, r_batch, done, e_batch = exp_queues[i].get()
            net.getNetworkGradient(s_batch, a_batch, r_batch, done)
            total_reward += np.sum(r_batch)
            total_batch_len += len(r_batch)
            episode_entropy += np.sum(e_batch)
        net.updateNetwork()

        epoch += 1
        avg_reward = total_reward / total_batch_len
        # avg_entropy = total_entropy / total_batch_len
        logging.info('Epoch ' + str(epoch) +
                     '\nAverage reward: ' + str(avg_reward) +
                     '\nEpisode entropy: ' + str(episode_entropy))

        ax.append(epoch)
        ay.append(episode_entropy)
        plt.clf()
        plt.plot(ax, ay)
        plt.pause(0.1)

        total_reward = 0.0
        total_batch_len = 0
        episode_entropy = 0.0

        if epoch % config['save_interval'] == 0:
            print('Train Epoch ' + str(epoch) + ', model saved.')
            print('Epoch costs ' + str(time.time() - start) + ' seconds.')
            torch.save(net.ActorNetwork.state_dict(),
                       config['model_dir'] + '/actor_300k_' + str(epoch) + '.pt')
            torch.save(net.CriticNetwork.state_dict(),
                       config['model_dir'] + '/critic_300k_' + str(epoch) + '.pt')
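# --- Hedged sketch: multiprocessing wiring for central_agent(). ---
# The file does not show how the queues are created or how workers are
# spawned. This is a minimal sketch, assuming one worker process per agent
# that pulls fresh actor parameters from net_params_queues[i] and pushes
# (s_batch, a_batch, r_batch, done, e_batch) tuples into exp_queues[i].
# The worker entry point `agent` is hypothetical and not defined here.
def _run_training(config):
    import multiprocessing as mp

    net_params_queues = [mp.Queue(1) for _ in range(config['num_agents'])]
    exp_queues = [mp.Queue(1) for _ in range(config['num_agents'])]

    # Central coordinator: broadcasts parameters, gathers experience, updates.
    coordinator = mp.Process(target=central_agent,
                             args=(net_params_queues, exp_queues, config))
    coordinator.start()

    # One environment worker per agent (hypothetical `agent` function).
    workers = [mp.Process(target=agent,
                          args=(i, net_params_queues[i], exp_queues[i], config))
               for i in range(config['num_agents'])]
    for w in workers:
        w.start()

    coordinator.join()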