def main(_):
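    # config.json is expected to contain an "apex" section with the keys read
    # below; illustrative values only (assumed, adjust to your setup):
    #   {"apex": {"server_ip": "localhost", "server_port": 8000, "num_actors": 4,
    #             "trajectory": 128, "model_input": [84, 84, 4], "model_output": 18,
    #             "queue_size": 1024, "batch_size": 32, ...}}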
    with open('config.json') as f:
        data = json.load(f)['apex']
    print(data)

    local_job_device = f'/job:{FLAGS.job_name}/task:{FLAGS.task}'
    shared_job_device = '/job:learner/task:0'
    is_learner = FLAGS.job_name == 'learner'

    cluster = tf.train.ClusterSpec({
        'actor': [
            'localhost:{}'.format(data['server_port'] + 1 + i)
            for i in range(data['num_actors'])
        ],
        'learner': ['{}:{}'.format(data['server_ip'], data['server_port'])]
    })
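    # One process per cluster entry: actors on consecutive localhost ports above
    # server_port, the learner at server_ip:server_port. A typical launch
    # (script name assumed) is one command per process, e.g.
    #   python train_apex.py --job_name=learner --task=0
    #   python train_apex.py --job_name=actor --task=<i>   # i in [0, num_actors)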

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task)

    with tf.device(shared_job_device):
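        # Created once on the learner so every process can reach them: the
        # cross-process trajectory queue and the learner network.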
        with tf.device('/cpu'):
            apex_queue = buffer_queue.ApexFIFOQueue(
                trajectory=data['trajectory'],
                input_shape=data['model_input'],
                output_size=data['model_output'],
                queue_size=data['queue_size'],
                batch_size=data['batch_size'],
                num_actors=data['num_actors'])

        learner = apex.Agent(input_shape=data['model_input'],
                             num_action=data['model_output'],
                             discount_factor=data['discount_factor'],
                             gradient_clip_norm=data['gradient_clip_norm'],
                             reward_clipping=data['reward_clipping'],
                             start_learning_rate=data['start_learning_rate'],
                             end_learning_rate=data['end_learning_rate'],
                             learning_frame=data['learning_frame'],
                             model_name='learner',
                             learner_name='learner')

    with tf.device(local_job_device):
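        # Per-process actor network; its weights are pulled from the learner
        # ('learner_name') via parameter_sync() in the rollout loop below.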

        actor = apex.Agent(input_shape=data['model_input'],
                           num_action=data['model_output'],
                           discount_factor=data['discount_factor'],
                           gradient_clip_norm=data['gradient_clip_norm'],
                           reward_clipping=data['reward_clipping'],
                           start_learning_rate=data['start_learning_rate'],
                           end_learning_rate=data['end_learning_rate'],
                           learning_frame=data['learning_frame'],
                           model_name=f'actor_{FLAGS.task}',
                           learner_name='learner')

    sess = tf.Session(server.target)
    apex_queue.set_session(sess)
    learner.set_session(sess)

    if not is_learner:
        actor.set_session(sess)

    if is_learner:
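        # Learner: drain trajectories from the shared queue, score them with the
        # current TD error, store them in prioritized replay, and train on
        # importance-weighted minibatches once enough data has accumulated.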
        import time

        learner.target_to_main()
        replay_buffer = buffer_queue.Memory(capacity=int(1e5))
        buffer_step = 0
        train_step = 0

        writer = SummaryWriter('runs/learner')

        while True:
            print(f'train step : {train_step} | buffer step : {buffer_step}')
            if apex_queue.get_size():

                buffer_step += 1

                from_actor = apex_queue.sample_batch(1)
                from_actor_state = from_actor.state[0]
                from_actor_next_state = from_actor.next_state[0]
                from_actor_previous_action = from_actor.previous_action[0]
                from_actor_action = from_actor.action[0]
                from_actor_reward = from_actor.reward[0]
                from_actor_done = from_actor.done[0]

                td_error = learner.get_td_error(
                    state=from_actor_state,
                    next_state=from_actor_next_state,
                    previous_action=from_actor_previous_action,
                    action=from_actor_action,
                    reward=from_actor_reward,
                    done=from_actor_done)

                for i in range(len(td_error)):
                    replay_buffer.add(td_error[i], [
                        from_actor_state[i], from_actor_next_state[i],
                        from_actor_previous_action[i], from_actor_action[i],
                        from_actor_reward[i], from_actor_done[i]
                    ])

            if buffer_step > 10:
                train_step += 1

                s = time.time()

                minibatch, idxs, is_weight = replay_buffer.sample(
                    data['batch_size'])
                minibatch = np.array(minibatch)

                state = np.stack(minibatch[:, 0])
                next_state = np.stack(minibatch[:, 1])
                previous_action = np.stack(minibatch[:, 2])
                action = np.stack(minibatch[:, 3])
                reward = np.stack(minibatch[:, 4])
                done = np.stack(minibatch[:, 5])

                loss, td_error = learner.distributed_train(
                    state=state,
                    next_state=next_state,
                    previous_action=previous_action,
                    action=action,
                    reward=reward,
                    done=done,
                    is_weight=is_weight)

                writer.add_scalar('data/loss', loss, train_step)
                writer.add_scalar('data/time', time.time() - s, train_step)

                if train_step % 100 == 0:
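                    # Re-sync the target and online networks every 100 updates.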
                    learner.target_to_main()

                for i in range(len(idxs)):
                    replay_buffer.update(idxs[i], td_error[i])

    else:
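        # Actor: interact with its own environment, keep transitions in a small
        # local buffer, and periodically push sampled trajectories to the
        # learner's queue; a lost life is treated as a terminal step with -1 reward.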
        env = wrappers.make_uint8_env(data['env'][FLAGS.task])
        local_buffer = buffer_queue.LocalBuffer(capacity=int(1e4))

        epsilon = 1.0
        train_step = 0
        episode = 0
        episode_step = 0
        total_max_prob = 0

        state = env.reset()
        previous_action = 0
        score = 0
        lives = 5

        writer = SummaryWriter('runs/{}/actor_{}'.format(
            data['env'][FLAGS.task], FLAGS.task))

        while True:

            actor.parameter_sync()

            for _ in range(data['trajectory']):

                action, q_value, max_q_value = actor.get_policy_and_action(
                    state=state,
                    previous_action=previous_action,
                    epsilon=epsilon)

                episode_step += 1
                total_max_prob += max_q_value

                next_state, reward, done, info = env.step(
                    action % data['available_action'][FLAGS.task])

                score += reward

                if lives != info['ale.lives']:
                    r = -1
                    d = True
                else:
                    r = reward
                    d = False

                local_buffer.append(state=state,
                                    done=d,
                                    reward=r,
                                    next_state=next_state,
                                    previous_action=previous_action,
                                    action=action)

                state = next_state
                previous_action = action
                lives = info['ale.lives']

                if len(local_buffer) > 3 * data['trajectory']:
                    train_step += 1
                    sampled_data = local_buffer.sample(data['trajectory'])
                    apex_queue.append_to_queue(
                        task=FLAGS.task,
                        unrolled_state=sampled_data['state'],
                        unrolled_next_state=sampled_data['next_state'],
                        unrolled_previous_action=sampled_data[
                            'previous_action'],
                        unrolled_action=sampled_data['action'],
                        unrolled_reward=sampled_data['reward'],
                        unrolled_done=sampled_data['done'])

                if done:
                    print(episode, score, epsilon)
                    writer.add_scalar('data/epsilon', epsilon, episode)
                    writer.add_scalar('data/episode_step', episode_step,
                                      episode)
                    writer.add_scalar('data/score', score, episode)
                    writer.add_scalar('data/total_max_prob',
                                      total_max_prob / episode_step, episode)
                    episode_step = 0
                    episode += 1
                    score = 0
                    lives = 5
                    epsilon = 1 / (episode * 0.05 + 1)
                    state = env.reset()
                    previous_action = 0
def main(_):
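    # IMPALA variant with the cluster layout, environment, and model sizes
    # hard-coded below instead of being read from config.json.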

    local_job_device = '/job:{}/task:{}'.format(FLAGS.job_name, FLAGS.task)
    shared_job_device = '/job:learner/task:0'
    is_actor_fn = lambda i: FLAGS.job_name == 'actor' and i == FLAGS.task
    is_learner = FLAGS.job_name == 'learner'

    cluster = tf.train.ClusterSpec({
        'actor':
        ['localhost:{}'.format(8001 + i) for i in range(FLAGS.num_actors)],
        'learner': ['localhost:8000']
    })

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task)

    filters = [shared_job_device, local_job_device]

    output_size = 18
    available_output_size = 6
    env_name = 'DemonAttackDeterministic-v4'
    input_shape = [84, 84, 4]

    with tf.device(shared_job_device):
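        # Shared on the learner: the trajectory queue (pinned to CPU) and the
        # IMPALA learner network with an LSTM core of FLAGS.lstm_size units.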

        with tf.device('/cpu'):
            queue = buffer_queue.FIFOQueue(FLAGS.trajectory, input_shape,
                                           output_size, FLAGS.queue_size,
                                           FLAGS.batch_size, FLAGS.num_actors,
                                           FLAGS.lstm_size)

        learner = model.IMPALA(trajectory=FLAGS.trajectory,
                               input_shape=input_shape,
                               num_action=output_size,
                               discount_factor=FLAGS.discount_factor,
                               start_learning_rate=FLAGS.start_learning_rate,
                               end_learning_rate=FLAGS.end_learning_rate,
                               learning_frame=FLAGS.learning_frame,
                               baseline_loss_coef=FLAGS.baseline_loss_coef,
                               entropy_coef=FLAGS.entropy_coef,
                               gradient_clip_norm=FLAGS.gradient_clip_norm,
                               reward_clipping=FLAGS.reward_clipping,
                               model_name='learner',
                               learner_name='learner',
                               lstm_hidden_size=FLAGS.lstm_size)

    with tf.device(local_job_device):
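        # Local actor copy of the IMPALA network; parameters are synced from the
        # learner before every rollout.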
        actor = model.IMPALA(trajectory=FLAGS.trajectory,
                             input_shape=input_shape,
                             num_action=output_size,
                             discount_factor=FLAGS.discount_factor,
                             start_learning_rate=FLAGS.start_learning_rate,
                             end_learning_rate=FLAGS.end_learning_rate,
                             learning_frame=FLAGS.learning_frame,
                             baseline_loss_coef=FLAGS.baseline_loss_coef,
                             entropy_coef=FLAGS.entropy_coef,
                             gradient_clip_norm=FLAGS.gradient_clip_norm,
                             reward_clipping=FLAGS.reward_clipping,
                             model_name='actor_{}'.format(FLAGS.task),
                             learner_name='learner',
                             lstm_hidden_size=FLAGS.lstm_size)

    sess = tf.Session(server.target)
    queue.set_session(sess)
    learner.set_session(sess)

    if not is_learner:
        actor.set_session(sess)

    if is_learner:
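        # Learner: wait until the queue holds more than three batches worth of
        # trajectories, then train on stacked batches and log losses, entropy,
        # learning rate, and per-step wall time.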

        writer = tensorboardX.SummaryWriter('runs/learner')
        train_step = 0

        while True:
            size = queue.get_size()
            if size > 3 * FLAGS.batch_size:
                train_step += 1
                batch = queue.sample_batch()
                s = time.time()
                pi_loss, baseline_loss, entropy, learning_rate = learner.train(
                    state=np.stack(batch.state),
                    reward=np.stack(batch.reward),
                    action=np.stack(batch.action),
                    done=np.stack(batch.done),
                    behavior_policy=np.stack(batch.behavior_policy),
                    previous_action=np.stack(batch.previous_action),
                    initial_h=np.stack(batch.previous_h),
                    initial_c=np.stack(batch.previous_c))
                writer.add_scalar('data/pi_loss', pi_loss, train_step)
                writer.add_scalar('data/baseline_loss', baseline_loss,
                                  train_step)
                writer.add_scalar('data/entropy', entropy, train_step)
                writer.add_scalar('data/learning_rate', learning_rate,
                                  train_step)
                writer.add_scalar('data/time', time.time() - s, train_step)
    else:
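        # Actor: roll out FLAGS.trajectory steps per iteration, recording the
        # behavior policy and the LSTM state at the start of each step, then
        # push the whole unroll to the learner's queue.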

        trajectory_data = collections.namedtuple('trajectory_data', [
            'state', 'next_state', 'reward', 'done', 'action',
            'behavior_policy', 'previous_action', 'initial_h', 'initial_c'
        ])

        env = wrappers.make_uint8_env(env_name)
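        # Task 0 additionally records a video of every 10th episode.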
        if FLAGS.task == 0:
            env = gym.wrappers.Monitor(
                env,
                'save-mov',
                video_callable=lambda episode_id: episode_id % 10 == 0)
        state = env.reset()
        previous_action = 0
        previous_h = np.zeros([FLAGS.lstm_size])
        previous_c = np.zeros([FLAGS.lstm_size])

        episode = 0
        score = 0
        episode_step = 0
        total_max_prob = 0
        lives = 4

        writer = tensorboardX.SummaryWriter('runs/{}/actor_{}'.format(
            env_name, FLAGS.task))

        while True:

            unroll_data = trajectory_data([], [], [], [], [], [], [], [], [])

            actor.parameter_sync()

            for _ in range(FLAGS.trajectory):

                action, behavior_policy, max_prob, h, c = actor.get_policy_and_action(
                    state, previous_action, previous_h, previous_c)

                episode_step += 1
                total_max_prob += max_prob

                next_state, reward, done, info = env.step(
                    action % available_output_size)

                score += reward

                if lives != info['ale.lives']:
                    r = -1
                    d = True
                else:
                    r = reward
                    d = False

                unroll_data.state.append(state)
                unroll_data.next_state.append(next_state)
                unroll_data.reward.append(r)
                unroll_data.done.append(d)
                unroll_data.action.append(action)
                unroll_data.behavior_policy.append(behavior_policy)
                unroll_data.previous_action.append(previous_action)
                unroll_data.initial_h.append(previous_h)
                unroll_data.initial_c.append(previous_c)

                state = next_state
                previous_action = action
                previous_h = h
                previous_c = c
                lives = info['ale.lives']

                if done:

                    print(episode, score)
                    writer.add_scalar('data/{}/prob'.format(env_name),
                                      total_max_prob / episode_step, episode)
                    writer.add_scalar('data/{}/score'.format(env_name), score,
                                      episode)
                    writer.add_scalar('data/{}/episode_step'.format(env_name),
                                      episode_step, episode)
                    episode += 1
                    score = 0
                    episode_step = 0
                    total_max_prob = 0
                    lives = 4
                    state = env.reset()
                    previous_action = 0
                    previous_h = np.zeros([FLAGS.lstm_size])
                    previous_c = np.zeros([FLAGS.lstm_size])

            queue.append_to_queue(
                task=FLAGS.task,
                unrolled_state=unroll_data.state,
                unrolled_next_state=unroll_data.next_state,
                unrolled_reward=unroll_data.reward,
                unrolled_done=unroll_data.done,
                unrolled_action=unroll_data.action,
                unrolled_behavior_policy=unroll_data.behavior_policy,
                unrolled_previous_action=unroll_data.previous_action,
                unrolled_previous_h=unroll_data.initial_h,
                unrolled_previous_c=unroll_data.initial_c)
def main(_):
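    # Same IMPALA setup as above, but fully driven by the "impala" section of
    # config.json (validated by utils.check_properties).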
    with open('config.json') as f:
        data = json.load(f)['impala']
    utils.check_properties(data)

    local_job_device = f'/job:{FLAGS.job_name}/task:{FLAGS.task}'
    shared_job_device = '/job:learner/task:0'
    is_learner = FLAGS.job_name == 'learner'

    cluster = tf.train.ClusterSpec({
        'actor': [
            'localhost:{}'.format(data['server_port'] + 1 + i)
            for i in range(data['num_actors'])
        ],
        'learner': ['{}:{}'.format(data['server_ip'], data['server_port'])]
    })

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task)

    with tf.device(shared_job_device):
        with tf.device('/cpu'):
            queue = buffer_queue.FIFOQueue(trajectory=data['trajectory'],
                                           input_shape=data['model_input'],
                                           output_size=data['model_output'],
                                           queue_size=data['queue_size'],
                                           batch_size=data['batch_size'],
                                           num_actors=data['num_actors'],
                                           lstm_size=data['lstm_size'])

        learner = impala.Agent(trajectory=data['trajectory'],
                               input_shape=data['model_input'],
                               num_action=data['model_output'],
                               lstm_hidden_size=data['lstm_size'],
                               discount_factor=data['discount_factor'],
                               start_learning_rate=data['start_learning_rate'],
                               end_learning_rate=data['end_learning_rate'],
                               learning_frame=data['learning_frame'],
                               baseline_loss_coef=data['baseline_loss_coef'],
                               entropy_coef=data['entropy_coef'],
                               gradient_clip_norm=data['gradient_clip_norm'],
                               reward_clipping=data['reward_clipping'],
                               model_name='learner',
                               learner_name='learner')

    with tf.device(local_job_device):

        actor = impala.Agent(trajectory=data['trajectory'],
                             input_shape=data['model_input'],
                             num_action=data['model_output'],
                             lstm_hidden_size=data['lstm_size'],
                             discount_factor=data['discount_factor'],
                             start_learning_rate=data['start_learning_rate'],
                             end_learning_rate=data['end_learning_rate'],
                             learning_frame=data['learning_frame'],
                             baseline_loss_coef=data['baseline_loss_coef'],
                             entropy_coef=data['entropy_coef'],
                             gradient_clip_norm=data['gradient_clip_norm'],
                             reward_clipping=data['reward_clipping'],
                             model_name=f'actor_{FLAGS.task}',
                             learner_name='learner')

    sess = tf.Session(server.target)
    queue.set_session(sess)
    learner.set_session(sess)

    if not is_learner:
        actor.set_session(sess)

    if is_learner:
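        # Learner loop: identical in shape to the hard-coded variant, with the
        # batch-size threshold taken from config.json.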

        writer = SummaryWriter('runs/learner')
        train_step = 0
        while True:
            size = queue.get_size()
            if size > 3 * data['batch_size']:
                print('train {}'.format(train_step))
                train_step += 1
                batch = queue.sample_batch()
                s = time.time()
                pi_loss, baseline_loss, entropy, learning_rate = learner.train(
                    state=np.stack(batch.state),
                    reward=np.stack(batch.reward),
                    action=np.stack(batch.action),
                    done=np.stack(batch.done),
                    behavior_policy=np.stack(batch.behavior_policy),
                    previous_action=np.stack(batch.previous_action),
                    initial_h=np.stack(batch.previous_h),
                    initial_c=np.stack(batch.previous_c))
                writer.add_scalar('data/pi_loss', pi_loss, train_step)
                writer.add_scalar('data/baseline_loss', baseline_loss,
                                  train_step)
                writer.add_scalar('data/entropy', entropy, train_step)
                writer.add_scalar('data/learning_rate', learning_rate,
                                  train_step)
                writer.add_scalar('data/time', time.time() - s, train_step)

    else:
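        # Actor: accumulate each unroll in utils.UnrolledTrajectory and ship it
        # to the learner's queue after data['trajectory'] steps.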

        trajectory = utils.UnrolledTrajectory()
        env = wrappers.make_uint8_env(data['env'][FLAGS.task])
        state = env.reset()
        previous_action = 0
        previous_h = np.zeros([data['lstm_size']])
        previous_c = np.zeros([data['lstm_size']])

        episode = 0
        score = 0
        episode_step = 0
        total_max_prob = 0
        lives = 5

        writer = SummaryWriter('runs/{}/actor_{}'.format(
            data['env'][FLAGS.task], FLAGS.task))

        while True:

            trajectory.initialize()
            actor.parameter_sync()

            for _ in range(data['trajectory']):

                action, behavior_policy, max_prob, h, c = actor.get_policy_and_action(
                    state, previous_action, previous_h, previous_c)

                episode_step += 1
                total_max_prob += max_prob

                next_state, reward, done, info = env.step(
                    action % data['available_action'][FLAGS.task])

                score += reward

                if lives != info['ale.lives']:
                    r = -1
                    d = True
                else:
                    r = reward
                    d = False

                trajectory.append(state=state,
                                  next_state=next_state,
                                  reward=r,
                                  done=d,
                                  action=action,
                                  behavior_policy=behavior_policy,
                                  previous_action=previous_action,
                                  initial_h=previous_h,
                                  initial_c=previous_c)

                state = next_state
                previous_action = action
                previous_h = h
                previous_c = c
                lives = info['ale.lives']

                if done:
                    print(episode, score)
                    writer.add_scalar(
                        'data/{}/prob'.format(data['env'][FLAGS.task]),
                        total_max_prob / episode_step, episode)
                    writer.add_scalar(
                        'data/{}/score'.format(data['env'][FLAGS.task]), score,
                        episode)
                    writer.add_scalar(
                        'data/{}/episode_step'.format(data['env'][FLAGS.task]),
                        episode_step, episode)
                    episode += 1
                    score = 0
                    episode_step = 0
                    total_max_prob = 0
                    state = env.reset()
                    previous_action = 0
                    previous_h = np.zeros([data['lstm_size']])
                    previous_c = np.zeros([data['lstm_size']])
                    lives = 5

            unrolled_data = trajectory.extract()
            queue.append_to_queue(
                task=FLAGS.task,
                unrolled_state=unrolled_data['state'],
                unrolled_next_state=unrolled_data['next_state'],
                unrolled_reward=unrolled_data['reward'],
                unrolled_done=unrolled_data['done'],
                unrolled_behavior_policy=unrolled_data['behavior_policy'],
                unrolled_action=unrolled_data['action'],
                unrolled_previous_action=unrolled_data['previous_action'],
                unrolled_previous_h=unrolled_data['initial_h'],
                unrolled_previous_c=unrolled_data['initial_c'])
def main(_):

    with open('config.json') as f:
        data = json.load(f)['a3c']
    utils.check_properties(data)

    local_job_device = f'/job:{FLAGS.job_name}/task:{FLAGS.task}'
    shared_job_device = '/job:learner/task:0'
    is_learner = FLAGS.job_name == 'learner'

    cluster = tf.train.ClusterSpec({
        'actor': [
            'localhost:{}'.format(data['server_port'] + 1 + i)
            for i in range(data['num_actors'])
        ],
        'learner': ['{}:{}'.format(data['server_ip'], data['server_port'])]
    })

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task)

    with tf.device(shared_job_device):
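        # Shared on the learner: the A3C trajectory queue and the global
        # (learner) network.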

        with tf.device('/cpu'):
            queue = buffer_queue.A3CFIFOQueue(
                trajectory_size=data['trajectory'],
                input_shape=data['model_input'],
                output_size=data['model_output'],
                num_actors=data['num_actors'])

        learner = a3c.Agent(input_shape=data['model_input'],
                            num_action=data['model_output'],
                            discount_factor=data['discount_factor'],
                            start_learning_rate=data['start_learning_rate'],
                            end_learning_rate=data['end_learning_rate'],
                            learning_frame=data['learning_frame'],
                            baseline_loss_coef=data['baseline_loss_coef'],
                            entropy_coef=data['entropy_coef'],
                            gradient_clip_norm=data['gradient_clip_norm'],
                            reward_clipping=data['reward_clipping'],
                            model_name='learner',
                            learner_name='learner')

    with tf.device(local_job_device):
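        # Local actor network; weights are pulled from the learner before each
        # rollout via parameter_sync().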

        actor = a3c.Agent(input_shape=data['model_input'],
                          num_action=data['model_output'],
                          discount_factor=data['discount_factor'],
                          start_learning_rate=data['start_learning_rate'],
                          end_learning_rate=data['end_learning_rate'],
                          learning_frame=data['learning_frame'],
                          baseline_loss_coef=data['baseline_loss_coef'],
                          entropy_coef=data['entropy_coef'],
                          gradient_clip_norm=data['gradient_clip_norm'],
                          reward_clipping=data['reward_clipping'],
                          model_name=f'actor_{FLAGS.task}',
                          learner_name='learner')

    sess = tf.Session(server.target)
    learner.set_session(sess)
    queue.set_session(sess)
    if not is_learner:
        actor.set_session(sess)

    if is_learner:
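        # Learner: consume one actor's unrolled trajectory at a time and apply
        # an A3C update, logging policy/value losses, entropy, and learning rate.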

        writer = SummaryWriter('runs/learner')
        train_step = 0
        while True:

            if queue.get_size():
                train_step += 1
                batch = queue.sample_batch()
                pi_loss, value_loss, entropy, learning_rate = learner.train(
                    state=batch.state[0],
                    next_state=batch.next_state[0],
                    previous_action=batch.previous_action[0],
                    action=batch.action[0],
                    reward=batch.reward[0],
                    done=batch.done[0])

                writer.add_scalar('data/pi_loss', pi_loss, train_step)
                writer.add_scalar('data/value_loss', value_loss, train_step)
                writer.add_scalar('data/entropy', entropy, train_step)
                writer.add_scalar('data/lr', learning_rate, train_step)

                print('#########')
                print(f'pi loss    : {pi_loss}')
                print(f'value loss : {value_loss}')
                print(f'entropy    : {entropy}')
                print(f'lr         : {learning_rate}')
                print(f'step       : {train_step}')

    else:
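        # Actor: unroll data['trajectory'] steps, treating a lost life as a
        # terminal step with -1 reward, then push the trajectory to the learner.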

        writer = SummaryWriter('runs/{}/actor_{}'.format(
            data['env'][FLAGS.task], FLAGS.task))

        trajectory = utils.UnrolledA3CTrajectory()
        env = wrappers.make_uint8_env(data['env'][FLAGS.task])
        state = env.reset()
        previous_action = 0

        episode = 0
        score = 0
        episode_step = 0
        total_max_prob = 0
        lives = 5

        while True:

            trajectory.initialize()
            actor.parameter_sync()

            for _ in range(data['trajectory']):

                action, policy, max_prob = actor.get_policy_and_action(
                    state=state, previous_action=previous_action)

                episode_step += 1
                total_max_prob += max_prob

                next_state, reward, done, info = env.step(
                    action % data['available_action'][FLAGS.task])

                score += reward

                if lives != info['ale.lives']:
                    r = -1
                    d = True
                else:
                    r = reward
                    d = False

                trajectory.append(state=state,
                                  next_state=next_state,
                                  previous_action=previous_action,
                                  action=action,
                                  reward=r,
                                  done=d)

                state = next_state
                previous_action = action
                lives = info['ale.lives']

                if done:
                    print(score, episode)
                    writer.add_scalar(
                        'data/{}/prob'.format(data['env'][FLAGS.task]),
                        total_max_prob / episode_step, episode)
                    writer.add_scalar(
                        'data/{}/score'.format(data['env'][FLAGS.task]), score,
                        episode)
                    writer.add_scalar(
                        'data/{}/episode_step'.format(data['env'][FLAGS.task]),
                        episode_step, episode)
                    episode += 1
                    score = 0
                    episode_step = 0
                    total_max_prob = 0
                    state = env.reset()
                    previous_action = 0
                    lives = 5

            unrolled_data = trajectory.extract()
            queue.append_to_queue(
                task=FLAGS.task,
                unrolled_state=unrolled_data['state'],
                unrolled_next_state=unrolled_data['next_state'],
                unrolled_previous_action=unrolled_data['previous_action'],
                unrolled_action=unrolled_data['action'],
                unrolled_reward=unrolled_data['reward'],
                unrolled_done=unrolled_data['done'])
def main(_):
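    # Minimal IMPALA variant for Pong: fixed ports, no LSTM state, and the actor
    # process reuses the learner graph directly for action selection.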

    local_job_device = '/job:{}/task:{}'.format(FLAGS.job_name, FLAGS.task)
    shared_job_device = '/job:learner/task:0'
    is_actor_fn = lambda i: FLAGS.job_name == 'actor' and i == FLAGS.task
    is_learner = FLAGS.job_name == 'learner'

    cluster = tf.train.ClusterSpec({
        'actor':
        ['localhost:{}'.format(8001 + i) for i in range(FLAGS.num_actors)],
        'learner': ['localhost:8000']
    })

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task)

    filters = [shared_job_device, local_job_device]

    input_shape = [84, 84, 4]
    output_size = 6
    env_name = 'PongDeterministic-v4'

    with tf.device(shared_job_device):
        queue = buffer_queue.FIFOQueue(FLAGS.trajectory, input_shape,
                                       output_size, FLAGS.queue_size,
                                       FLAGS.batch_size, FLAGS.num_actors)
        learner = model.IMPALA(trajectory=FLAGS.trajectory,
                               input_shape=input_shape,
                               num_action=output_size,
                               discount_factor=FLAGS.discount_factor,
                               start_learning_rate=FLAGS.start_learning_rate,
                               end_learning_rate=FLAGS.end_learning_rate,
                               learning_frame=FLAGS.learning_frame,
                               baseline_loss_coef=FLAGS.baseline_loss_coef,
                               entropy_coef=FLAGS.entropy_coef,
                               gradient_clip_norm=FLAGS.gradient_clip_norm)

    sess = tf.Session(server.target)
    queue.set_session(sess)
    learner.set_session(sess)

    if is_learner:
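        # Learner loop: same structure as the LSTM variants above, minus the
        # recurrent state and previous-action inputs.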

        writer = tensorboardX.SummaryWriter('runs/learner')
        train_step = 0

        while True:
            size = queue.get_size()
            if size > 3 * FLAGS.batch_size:
                train_step += 1
                batch = queue.sample_batch()
                s = time.time()
                pi_loss, baseline_loss, entropy, learning_rate = learner.train(
                    state=np.stack(batch.state),
                    reward=np.stack(batch.reward),
                    action=np.stack(batch.action),
                    done=np.stack(batch.done),
                    behavior_policy=np.stack(batch.behavior_policy))
                writer.add_scalar('data/pi_loss', pi_loss, train_step)
                writer.add_scalar('data/baseline_loss', baseline_loss,
                                  train_step)
                writer.add_scalar('data/entropy', entropy, train_step)
                writer.add_scalar('data/learning_rate', learning_rate,
                                  train_step)
                writer.add_scalar('data/time', time.time() - s, train_step)
    else:
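        # Actor: a conceded point (reward == -1) marks the transition as
        # terminal; each unroll of FLAGS.trajectory steps goes to the learner's queue.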

        trajectory_data = collections.namedtuple('trajectory_data', [
            'state', 'next_state', 'reward', 'done', 'action',
            'behavior_policy'
        ])

        env = wrappers.make_uint8_env(env_name)
        if FLAGS.task == 0:
            env = gym.wrappers.Monitor(
                env,
                'save-mov',
                video_callable=lambda episode_id: episode_id % 10 == 0)
        state = env.reset()

        episode = 0
        score = 0
        episode_step = 0
        total_max_prob = 0

        writer = tensorboardX.SummaryWriter('runs/actor_{}'.format(FLAGS.task))

        while True:

            unroll_data = trajectory_data([], [], [], [], [], [])

            for _ in range(FLAGS.trajectory):

                action, behavior_policy, max_prob = learner.get_policy_and_action(
                    state)

                episode_step += 1
                total_max_prob += max_prob

                next_state, reward, done, info = env.step(action)

                score += reward

                d = False
                if reward == -1:
                    d = True

                unroll_data.state.append(state)
                unroll_data.next_state.append(next_state)
                unroll_data.reward.append(reward)
                unroll_data.done.append(d)
                unroll_data.action.append(action)
                unroll_data.behavior_policy.append(behavior_policy)

                state = next_state

                if done:

                    print(episode, score)
                    writer.add_scalar('data/prob',
                                      total_max_prob / episode_step, episode)
                    writer.add_scalar('data/score', score, episode)
                    writer.add_scalar('data/episode_step', episode_step,
                                      episode)
                    episode += 1
                    score = 0
                    episode_step = 0
                    total_max_prob = 0
                    state = env.reset()

            queue.append_to_queue(
                task=FLAGS.task,
                unrolled_state=unroll_data.state,
                unrolled_next_state=unroll_data.next_state,
                unrolled_reward=unroll_data.reward,
                unrolled_done=unroll_data.done,
                unrolled_action=unroll_data.action,
                unrolled_behavior_policy=unroll_data.behavior_policy)