state_shape = env.observation_space.shape
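# Build the DQN agent from the environment's observation shape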
__agent = DQN(
    f_create_q=f_net, state_shape=state_shape,
    # OneStepTD arguments
    num_actions=num_actions, discount_factor=gamma, ddqn=if_ddqn,
    # target network sync arguments
    target_sync_interval=target_sync_interval,
    target_sync_rate=target_sync_rate,
    # epsilon greedy arguments
    greedy_epsilon=greedy_epsilon,
    # optimizer arguments
    network_optimizer=LocalOptimizer(optimizer_td, max_grad_norm),
    # sampler arguments
    sampler=TransitionSampler(
        replay_buffer,
        batch_size=batch_size,
        interval=update_interval,
        minimum_count=sample_mimimum_count),
    # checkpoint
    global_step=global_step
)
# Utilities
stepsSaver = StepsSaver(our_log_dir)
reward_vector2scalar = FuncReward(gamma)
# Configure sess
config = tf.ConfigProto()
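# Cap the fraction of GPU memory this process is allowed to allocate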
config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_fraction
with __agent.create_session(
        config=config, save_dir=tf_log_dir,
        save_checkpoint_secs=save_checkpoint_secs) as sess, \
        AsynchronousAgent(__agent) as agent:  # constructor arguments assumed
    pass  # training loop not shown
Example #2
agent = DQN(  # f_create_q argument omitted here
    state_shape=state_shape,
    # OneStepTD arguments
    num_actions=len(ACTIONS),
    discount_factor=0.9,
    ddqn=False,
    # target network sync arguments
    target_sync_interval=1,
    target_sync_rate=target_sync_rate,
    # epsilon greedy arguments
    greedy_epsilon=0.2,
    # optimizer arguments
    network_optimizer=hrl.network.LocalOptimizer(optimizer_td, 10.0),
    # max_gradient=10.0,
    # sampler arguments
    sampler=TransitionSampler(BalancedMapPlayback(num_actions=len(ACTIONS),
                                                  capacity=15000),
                              batch_size=8,
                              interval=1),
    # checkpoint
    global_step=global_step)


def log_info(update_info):
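    """Accumulate per-step training statistics (e.g. TD loss and reward) for logging."""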
    global action_fraction
    global action_td_loss
    global agent
    global next_state
    global ACTIONS
    global n_steps
    global done
    global cum_td_loss
    global cum_reward