def single_agent():
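    """Train a single ActorCritic agent against GymEnv in one process.

    Each step sends the current bandwidth estimate (bwe) to the environment,
    records (state, action, reward), and at the end of every episode applies
    one actor-critic update, plots the episode reward, and periodically
    checkpoints both networks.
    """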
    config = load_config()
    # num_agents = config['num_agents']
    torch.set_num_threads(1)

    env = GymEnv(config=config)
    env.reset()

    net = ActorCritic(True, config)
    net.ActorNetwork.init_params()
    net.CriticNetwork.init_params()

    bwe = config['sending_rate'][config['default_bwe']]

    i = 1
    s_batch = []
    r_batch = []
    a_batch = []

    # interact with the RTC environment until training is stopped
    ax = []
    ay = []
    plt.ion()
    while True:
        # agent interacts with the gym environment
        state, reward, done, _ = env.step(bwe)

        r_batch.append(reward)

        action = net.predict(state)
        bwe = config['sending_rate'][action]
        a_batch.append(action)
        s_batch.append(state)

        # TODO: the end-of-episode handling needs to be fixed
        if done:
            action = config['default_bwe']
            bwe = config['sending_rate'][action]
            # update network
            net.getNetworkGradient(s_batch, a_batch, r_batch, done)
            net.updateNetwork()
            print('Network updated.')

            i += 1
            ax.append(i)
            # ay.append(entropy)
            ay.append(reward)
            plt.clf()
            plt.plot(ax, ay)
            plt.pause(0.1)
            # clear the episode buffers so the next update only uses fresh experience
            del s_batch[:]
            del a_batch[:]
            del r_batch[:]
            env.reset()
            print('Environment has been reset.')
            print('Epoch {}, Reward: {}'.format(i - 1, reward))
            # checkpoint every 100 episodes
            if i % 100 == 0:
                # print('Current BWE: ' + str(bwe))
                torch.save(net.ActorNetwork.state_dict(),
                           config['model_dir'] + '/actor1_{}.pt'.format(str(i)))
                torch.save(net.CriticNetwork.state_dict(),
                           config['model_dir'] + '/critic13m_{}.pt'.format(str(i)))
                print('Model saved.')


def central_agent(net_params_queue, exp_queues, config):
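    """Central learner for multi-agent (A3C-style) training.

    Broadcasts the latest actor parameters to every worker through
    net_params_queue, collects (s_batch, a_batch, r_batch, done, e_batch)
    tuples from exp_queues, accumulates gradients, applies one update per
    epoch, logs progress, and periodically checkpoints both networks.
    """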
    torch.set_num_threads(1)

    # log training info
    logging.basicConfig(filename=config['log_dir'] +
                        '/Central_agent_training.log',
                        filemode='w',
                        level=logging.INFO)

    assert len(net_params_queue) == config['num_agents']
    assert len(exp_queues) == config['num_agents']

    net = ActorCritic(True, config)

    # the original Pensieve workers do not use the critic, so only the actor
    # parameters are pushed into net_params_queue; parameters of both networks
    # are saved separately
    if config['load_model']:
        actor_net_params = torch.load(config['model_dir'] +
                                      '/actor_300k1_80.pt')
        critic_net_params = torch.load(config['model_dir'] +
                                       '/critic_300k1_80.pt')
        net.ActorNetwork.load_state_dict(actor_net_params)
        net.CriticNetwork.load_state_dict(critic_net_params)
    else:
        net.ActorNetwork.init_params()
        net.CriticNetwork.init_params()
    actor_net_params = list(net.ActorNetwork.parameters())
    for i in range(config['num_agents']):
        net_params_queue[i].put(actor_net_params)

    epoch = 0
    total_reward = 0.0
    total_batch_len = 0.0
    episode_entropy = 0.0
    ax = []
    ay = []
    plt.ion()

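    # main training loop: broadcast the latest actor parameters to all workers,
    # collect their experience batches, accumulate gradients, and apply one
    # central update per epoch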
    while True:
        start = time.time()
        actor_net_params = list(net.ActorNetwork.parameters())
        for i in range(config['num_agents']):
            net_params_queue[i].put(actor_net_params)

        for i in range(config['num_agents']):
            s_batch, a_batch, r_batch, done, e_batch = exp_queues[i].get()

            net.getNetworkGradient(s_batch, a_batch, r_batch, done)

            total_reward += np.sum(r_batch)
            total_batch_len += len(r_batch)
            episode_entropy += np.sum(e_batch)

        net.updateNetwork()
        epoch += 1
        avg_reward = total_reward / total_batch_len
        # avg_entropy = total_entropy / total_batch_len

        logging.info('Epoch ' + str(epoch) + '\nAverage reward: ' +
                     str(avg_reward) + '\nEpisode entropy: ' +
                     str(episode_entropy))
        ax.append(epoch)
        ay.append(episode_entropy)
        plt.clf()
        plt.plot(ax, ay)
        plt.pause(0.1)

        total_reward = 0.0
        total_batch_len = 0.0
        episode_entropy = 0.0

        if epoch % config['save_interval'] == 0:
            print('Train Epoch ' + str(epoch) + ', Model saved.')
            print('Epoch costs ' + str(time.time() - start) + ' seconds.')
            torch.save(
                net.ActorNetwork.state_dict(),
                config['model_dir'] + '/actor_300k_' + str(epoch) + '.pt')
            torch.save(
                net.CriticNetwork.state_dict(),
                config['model_dir'] + '/critic_300k_' + str(epoch) + '.pt')
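

# A minimal launch sketch (an assumption, not part of the original training
# logic): it wires central_agent() to per-worker queues with
# torch.multiprocessing. The worker entry point `agent` is hypothetical here;
# it stands for a function defined elsewhere in the project that pulls actor
# parameters from its queue, runs episodes against GymEnv, and pushes
# (s_batch, a_batch, r_batch, done, e_batch) tuples back, matching what
# central_agent() expects above.
def main():
    import torch.multiprocessing as mp

    config = load_config()
    num_agents = config['num_agents']

    # one parameter queue and one experience queue per worker
    net_params_queues = [mp.Queue() for _ in range(num_agents)]
    exp_queues = [mp.Queue() for _ in range(num_agents)]

    # central learner runs in its own process
    coordinator = mp.Process(target=central_agent,
                             args=(net_params_queues, exp_queues, config))
    coordinator.start()

    # hypothetical worker processes (`agent` is assumed to be defined elsewhere)
    workers = [mp.Process(target=agent,
                          args=(net_params_queues[i], exp_queues[i], config))
               for i in range(num_agents)]
    for w in workers:
        w.start()

    coordinator.join()
    for w in workers:
        w.join()


if __name__ == '__main__':
    main()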