def __init__(self, args):
        """Initialise a PGQ actor-learner: networks, Q-learning ops, replay.

        The PGQ hyper-parameters are copied from ``args`` before any graph is
        built, because the class's ``_build_q_ops`` reads ``pgq_fraction``
        and ``batch_update_size`` from ``self``. The master actor also gets a
        checkpoint ``Saver`` over the local network's parameters.
        """
        super(BasePGQLearner, self).__init__(args)

        # Hyper-parameters first: graph construction below depends on them.
        self.q_update_counter = 0
        self.replay_size = args.replay_size
        self.pgq_fraction = args.pgq_fraction
        self.batch_update_size = args.batch_update_size

        network_conf = dict(
            name='local_learning_{}'.format(self.actor_id),
            input_shape=self.input_shape,
            num_act=self.num_actions,
            args=args,
        )

        # CPU copy used for acting.
        with tf.device('/cpu:0'):
            self.local_network = PolicyValueNetwork(network_conf)
        # GPU copy built under a reusing variable scope, plus the Q ops.
        with tf.device('/gpu:0'), tf.variable_scope('', reuse=True):
            self.batch_network = PolicyValueNetwork(network_conf)
            self._build_q_ops()

        self.reset_hidden_state()
        state_shape = self.local_network.get_input_shape()
        self.replay_memory = ReplayMemory(self.replay_size, state_shape,
                                          self.num_actions)

        if not self.is_master():
            return
        self.saver = tf.train.Saver(var_list=self.local_network.params,
                                    max_to_keep=3,
                                    keep_checkpoint_every_n_hours=2)
Example #2
0
    def __init__(self, args):
        """Create a TRPO learner whose local network is policy-only.

        Entropy regularisation is switched off on ``args`` before the base
        initialiser runs. Actor 0 additionally owns a checkpoint ``Saver``
        over the policy parameters. The TRPO ops are built last, once all
        hyper-parameters have been set.
        """
        args.entropy_regularisation_strength = 0.0
        super(TRPOLearner, self).__init__(args)

        # A separate, policy-only network (value head disabled), as in the
        # paper, so that value learning does not damage the trust-region
        # updates.
        policy_conf = dict(
            name='local_learning_{}'.format(self.actor_id),
            input_shape=self.input_shape,
            num_act=self.num_actions,
            args=args,
        )
        self.policy_network = PolicyValueNetwork(policy_conf,
                                                 use_value_head=False)
        self.local_network = self.policy_network

        if self.actor_id == 0:
            self.saver = tf.train.Saver(
                var_list=self.policy_network.params,
                max_to_keep=3,
                keep_checkpoint_every_n_hours=2)

        # TRPO hyper-parameters; the damping constant is fixed in this variant.
        self.batch_size = 512
        self.cg_damping = 0.001
        self.max_kl, self.max_rollout, self.episodes_per_batch = (
            args.max_kl, args.max_rollout, args.trpo_episodes)
        self._build_ops()
Example #3
0
    def __init__(self, args):
        """Create per-actor TRPO policy and value networks and their ops.

        Entropy regularisation is forced to zero on ``args`` before the base
        initialiser runs. The master actor builds its graph on ``/gpu:0`` and
        owns a checkpoint ``Saver`` over both networks; the other actors
        build on ``/cpu:0``.
        """
        args.entropy_regularisation_strength = 0.0
        super(TRPOLearner, self).__init__(args)

        # Fixed solver settings.
        self.batch_size = 512
        self.max_cg_iters = 20

        # Options and shared handles copied verbatim from the command line.
        for option in ('num_epochs', 'cg_damping', 'cg_subsample', 'max_kl',
                       'max_rollout', 'episodes_per_batch', 'baseline_vars',
                       'experience_queue', 'task_queue'):
            setattr(self, option, getattr(args, option))

        policy_conf = {
            'name': 'policy_network_{}'.format(self.actor_id),
            'input_shape': self.input_shape,
            'num_act': self.num_actions,
            'args': args,
        }
        value_conf = dict(policy_conf,
                          name='value_network_{}'.format(self.actor_id))

        if self.is_master():
            self.device = '/gpu:0'
        else:
            self.device = '/cpu:0'

        with tf.device(self.device):
            # Separate policy and value networks, as in the paper, so value
            # fitting does not damage the trust-region policy updates.
            self.policy_network = PolicyValueNetwork(policy_conf,
                                                     use_value_head=False)
            self.value_network = PolicyValueNetwork(value_conf,
                                                    use_policy_head=False)
            self.local_network = self.policy_network
            self._build_ops()

        if self.is_master():
            checkpointed = self.policy_network.params + self.value_network.params
            self.saver = tf.train.Saver(var_list=checkpointed,
                                        max_to_keep=3,
                                        keep_checkpoint_every_n_hours=2)
    def __init__(self, args):
        """Build this actor's local policy/value network.

        Actor 0 additionally creates a checkpoint ``Saver`` over the local
        network's parameters.
        """
        super(A3CLearner, self).__init__(args)

        self.local_network = PolicyValueNetwork(dict(
            name='local_learning_{}'.format(self.actor_id),
            input_shape=self.input_shape,
            num_act=self.num_actions,
            args=args,
        ))
        self.reset_hidden_state()

        if self.actor_id != 0:
            return
        # Only the first actor checkpoints the parameters.
        self.saver = tf.train.Saver(var_list=self.local_network.params,
                                    max_to_keep=3,
                                    keep_checkpoint_every_n_hours=2)
class BasePGQLearner(BaseA3CLearner):
    """A3C actor-learner with an additional PGQ-style Q-learning update.

    Besides the local acting network, a second ``batch_network`` is built
    on ``/gpu:0`` under a reusing variable scope. ``batch_q_update`` trains
    it on transitions sampled from ``self.replay_memory`` and pushes the
    resulting gradients to the shared parameters.
    """

    def __init__(self, args):
        """Read PGQ hyper-parameters and build networks, Q ops and replay.

        ``replay_size``, ``pgq_fraction`` and ``batch_update_size`` are read
        from ``args`` before the graph is built, because ``_build_q_ops``
        uses them. The master actor also gets a checkpoint ``Saver``.
        """
        super(BasePGQLearner, self).__init__(args)

        self.q_update_counter = 0
        self.replay_size = args.replay_size
        self.pgq_fraction = args.pgq_fraction
        self.batch_update_size = args.batch_update_size
        scope_name = 'local_learning_{}'.format(self.actor_id)
        conf_learning = {'name': scope_name,
                         'input_shape': self.input_shape,
                         'num_act': self.num_actions,
                         'args': args}

        # The acting network lives on the CPU. The batch network is built
        # from the same config under a reuse=True scope (presumably sharing
        # the local network's variables) and runs batched Q updates on GPU.
        with tf.device('/cpu:0'):
            self.local_network = PolicyValueNetwork(conf_learning)
        with tf.device('/gpu:0'), tf.variable_scope('', reuse=True):
            self.batch_network = PolicyValueNetwork(conf_learning)
            self._build_q_ops()

        self.reset_hidden_state()
        self.replay_memory = ReplayMemory(
            self.replay_size,
            self.local_network.get_input_shape(),
            self.num_actions)

        if self.is_master():
            var_list = self.local_network.params
            self.saver = tf.train.Saver(var_list=var_list, max_to_keep=3,
                                        keep_checkpoint_every_n_hours=2)

    def _build_q_ops(self):
        """Build the PGQ Q-learning objective and its clipped gradients.

        The batch network is fed a batch whose first half holds states s_i
        and whose second half holds their successor states s_f (see
        ``batch_q_update``). Every Q-related tensor is therefore split in
        half along axis 0.
        """
        # pgq specific initialization
        self.batch_size = self.batch_update_size
        # Q estimate implied by the policy/value heads:
        # Q~ = beta * (log pi + H(pi)) + V, one row per input state.
        self.q_tilde = self.batch_network.beta * (
            self.batch_network.log_output_layer_pi
            + tf.expand_dims(self.batch_network.output_layer_entropy, 1)
        ) + self.batch_network.output_layer_v

        # First half of the batch -> current states, second -> successors.
        self.Qi, self.Qi_plus_1 = tf.split(axis=0, num_or_size_splits=2, value=self.q_tilde)
        self.V, _ = tf.split(axis=0, num_or_size_splits=2, value=self.batch_network.output_layer_v)
        self.log_pi, _ = tf.split(axis=0, num_or_size_splits=2, value=tf.expand_dims(self.batch_network.log_output_selected_action, 1))
        self.R = tf.placeholder('float32', [None], name='1-step_reward')

        # gamma * max_a Q~(s', a), zeroed for terminal transitions.
        self.terminal_indicator = tf.placeholder(tf.float32, [None], name='terminal_indicator')
        self.max_TQ = self.gamma*tf.reduce_max(self.Qi_plus_1, 1) * (1 - self.terminal_indicator)
        # Q~(s, a) of the taken action; selected_action_ph rows are
        # presumably one-hot, and only the first half of the batch is used.
        self.Q_a = tf.reduce_sum(self.Qi * tf.split(axis=0, num_or_size_splits=2, value=self.batch_network.selected_action_ph)[0], 1)

        # The 1-step TD error is held constant (stop_gradient) and weights
        # the gradient of 0.5 * V(s) + log pi(a|s).
        self.q_objective = - self.pgq_fraction * tf.reduce_mean(tf.stop_gradient(self.R + self.max_TQ - self.Q_a) * (0.5 * self.V[:, 0] + self.log_pi[:, 0]))

        self.V_params = self.batch_network.params
        self.q_gradients = tf.gradients(self.q_objective, self.V_params)
        self.q_gradients = self.batch_network._clip_grads(self.q_gradients)

    def batch_q_update(self):
        """Apply one PGQ Q-learning step from a replay minibatch.

        Does nothing until the replay memory holds at least a tenth of its
        capacity. Otherwise it samples ``self.batch_size`` transitions and
        evaluates ``self.q_gradients`` on [s_i; s_f]. The gradients are then
        applied to the shared-memory variables.
        """
        if len(self.replay_memory) < self.replay_memory.maxlen // 10:
            return

        s_i, a_i, r_i, s_f, is_terminal = self.replay_memory.sample_batch(self.batch_size)

        batch_grads = self.session.run(
            self.q_gradients,
            feed_dict={
                self.R: r_i,
                self.batch_network.selected_action_ph: np.vstack([a_i, a_i]),
                self.batch_network.input_ph: np.vstack([s_i, s_f]),
                # np.int was removed in NumPy 1.24; the placeholder is
                # float32, so feed that dtype directly.
                self.terminal_indicator: is_terminal.astype(np.float32),
            }
        )
        self.apply_gradients_to_shared_memory_vars(batch_grads)
Example #6
0
def main(args):
    """Set up the graph, the agents, and run the agents in parallel.

    Builds the shared-variable network for ``args.alg_type`` and allocates
    the shared parameter and optimiser state. It then starts
    ``args.num_actor_learners`` learners and blocks until they finish,
    terminating them on Ctrl-C.

    Raises:
        ValueError: if ``args.env`` is neither ``'GYM'`` nor ``'DOOM'``.
            Only those environments provide the action space and input
            shape needed below.
    """
    args.batch_size = None
    logger.debug('CONFIGURATION: {}'.format(args))
    if args.env == 'GYM':
        from environments import atari_environment
        num_actions, action_space, _ = atari_environment.get_actions(args.game)
        input_shape = atari_environment.get_input_shape(args.game)
    elif args.env == 'DOOM':
        from environments.vizdoom_env import VizDoomEnv
        env = VizDoomEnv(args.doom_cfg, args.game, args.is_train)
        num_actions, action_space = env.get_actions()
        input_shape = env.get_input_shape()
    else:
        # The old ALE fallback only computed num_actions. action_space and
        # input_shape were never bound, so it always died with an
        # UnboundLocalError a few lines below. Fail clearly instead.
        raise ValueError(
            "Unsupported env {!r}: expected 'GYM' or 'DOOM'".format(args.env))

    args.action_space = action_space
    args.summ_base_dir = '/tmp/summary_logs/{}/{}'.format(
        args.game, time.strftime('%m.%d/%H.%M'))
    logger.info('logging summaries to {}'.format(args.summ_base_dir))

    Learner, Network = ALGORITHMS[args.alg_type]
    network = Network({
        'name': 'shared_vars_network',
        'input_shape': input_shape,
        'num_act': num_actions,
        'args': args
    })
    args.network = Network

    # initialize shared variables
    args.learning_vars = SharedVars(network.params)
    args.opt_state = SharedVars(
        network.params, opt_type=args.opt_type,
        lr=args.initial_lr) if args.opt_mode == 'shared' else None
    args.batch_opt_state = SharedVars(
        network.params, opt_type=args.opt_type,
        lr=args.initial_lr) if args.opt_mode == 'shared' else None

    # TODO: need to refactor so TRPO+GAE doesn't need special treatment
    if args.alg_type in ['trpo', 'trpo-continuous']:
        if args.arch == 'FC':  # add timestep feature
            vf_input_shape = [input_shape[0] + 1]
        else:
            vf_input_shape = input_shape

        baseline_network = PolicyValueNetwork(
            {
                'name': 'shared_value_network',
                'input_shape': vf_input_shape,
                'num_act': num_actions,
                'args': args
            },
            use_policy_head=False)
        args.baseline_vars = SharedVars(baseline_network.params)
        args.vf_input_shape = vf_input_shape

    if args.alg_type in ['q', 'sarsa', 'dueling', 'dqn-cts']:
        args.target_vars = SharedVars(network.params)
        args.target_update_flags = SharedFlags(args.num_actor_learners)
    if args.alg_type == 'dqn-cts':
        args.density_model_update_flags = SharedFlags(args.num_actor_learners)

    tf.reset_default_graph()
    args.barrier = Barrier(args.num_actor_learners)
    args.global_step = SharedCounter(0)
    args.num_actions = num_actions

    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
    num_gpus = 0
    if cuda_visible_devices:
        # CUDA_VISIBLE_DEVICES is comma-separated (e.g. "0,1"). Splitting on
        # whitespace counted that as a single GPU.
        num_gpus = len([dev for dev in cuda_visible_devices.split(',')
                        if dev.strip()])

    # visualize == 2 means "only the last learner renders"; every other
    # learner gets visualize = 0. The flag is captured once up front. The old
    # code reset args.visualize to 0 first, so its in-loop test was never
    # true (and the branch assigned to args.args.visualize).
    visualize_last_only = args.visualize == 2
    if visualize_last_only:
        args.visualize = 0

    # spin up processes and block
    actor_learners = []
    task_queue = Queue()
    experience_queue = Queue()
    seed = args.seed or np.random.randint(2**32)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    for i in range(args.num_actor_learners):
        if visualize_last_only and i == args.num_actor_learners - 1:
            args.visualize = 1

        args.actor_id = i
        args.device = '/gpu:{}'.format(i % num_gpus) if num_gpus else '/cpu:0'

        args.random_seed = seed + i

        # only used by TRPO
        args.task_queue = task_queue
        args.experience_queue = experience_queue

        args.input_shape = input_shape
        actor_learners.append(Learner(args))
        actor_learners[-1].start()

    try:
        for t in actor_learners:
            t.join()
    except KeyboardInterrupt:
        # Terminate with extreme prejudice
        for t in actor_learners:
            t.terminate()

    logger.info('All training threads finished!')
    logger.info('Use seed={} to reproduce'.format(seed))
Example #7
0
def main(args):
    """Set up the graph, the agents, and run the agents in parallel.

    For ``args.alg_type == 'AE'``, two shared networks ("lower" and "upper")
    are built, each with its own shared parameter, optimiser and target
    state. Every other algorithm uses a single shared network. A Visdom
    client is stored on ``args.visdom``.

    Raises:
        ValueError: if ``args.env`` is not ``'GYM'``. Only that branch
            provides the action space and input shape used below.
    """
    args.batch_size = None
    logger.debug('CONFIGURATION: {}'.format(args))

    if args.env == 'GYM':
        from environments import atari_environment
        num_actions, action_space, _ = atari_environment.get_actions(args.game)
        input_shape = atari_environment.get_input_shape(args.game)
    else:
        # The old fallback only computed num_actions. action_space and
        # input_shape stayed unbound, so it always crashed with an
        # UnboundLocalError just below. Fail with a clear message instead.
        raise ValueError(
            "Unsupported env {!r}: only 'GYM' is supported".format(args.env))

    args.action_space = action_space
    args.summ_base_dir = '/tmp/summary_logs/{}/{}'.format(args.game, time.strftime('%m.%d/%H.%M'))
    logger.info('logging summaries to {}'.format(args.summ_base_dir))

    Learner, Network = ALGORITHMS[args.alg_type]

    if args.alg_type != 'AE':
        network = Network({
            'name': 'shared_vars_network',
            'input_shape': input_shape,
            'num_act': num_actions,
            'args': args
        })
        args.network = Network
    else:
        network_lower = Network({
            'name': 'shared_vars_network_lower',
            'input_shape': input_shape,
            'num_act': num_actions,
            'args': args
        })
        args.network_lower = Network

        network_upper = Network({
            'name': 'shared_vars_network_upper',
            'input_shape': input_shape,
            'num_act': num_actions,
            'args': args
        })
        args.network_upper = Network

    # initialize visdom server
    args.visdom = visdom.Visdom(port=args.display_port, env='AE DQN')

    # initialize shared variables (size, step and optimizer)
    # TODO: only network lower params are being used; check whether upper is also needed.
    if args.alg_type != 'AE':
        args.learning_vars = SharedVars(network.params)
        args.opt_state = SharedVars(
            network.params, opt_type=args.opt_type, lr=args.initial_lr
        ) if args.opt_mode == 'shared' else None
        args.batch_opt_state = SharedVars(
            network.params, opt_type=args.opt_type, lr=args.initial_lr
        ) if args.opt_mode == 'shared' else None
    else:
        args.learning_vars_lower = SharedVars(network_lower.params)
        args.learning_vars_upper = SharedVars(network_upper.params)
        # NOTE(review): unlike their *_upper counterparts, the *_lower
        # optimiser states are allocated regardless of args.opt_mode.
        # Confirm this asymmetry is intended.
        args.opt_state_lower = SharedVars(
            network_lower.params, opt_type=args.opt_type, lr=args.initial_lr
        )
        args.opt_state_upper = SharedVars(
            network_upper.params, opt_type=args.opt_type, lr=args.initial_lr
        ) if args.opt_mode == 'shared' else None
        args.batch_opt_state_lower = SharedVars(
            network_lower.params, opt_type=args.opt_type, lr=args.initial_lr
        )
        args.batch_opt_state_upper = SharedVars(
            network_upper.params, opt_type=args.opt_type, lr=args.initial_lr
        ) if args.opt_mode == 'shared' else None
        # Keep the historical misspelling as an alias so existing readers
        # of args.batch_opt_state_uppper keep working.
        args.batch_opt_state_uppper = args.batch_opt_state_upper

    # TODO: need to refactor so TRPO+GAE doesn't need special treatment
    if args.alg_type in ['trpo', 'trpo-continuous']:
        if args.arch == 'FC':  # add timestep feature
            vf_input_shape = [input_shape[0] + 1]
        else:
            vf_input_shape = input_shape

        baseline_network = PolicyValueNetwork({
            'name': 'shared_value_network',
            'input_shape': vf_input_shape,
            'num_act': num_actions,
            'args': args
        }, use_policy_head=False)
        args.baseline_vars = SharedVars(baseline_network.params)
        args.vf_input_shape = vf_input_shape

    if args.alg_type in ['q', 'sarsa', 'dueling', 'dqn-cts']:
        args.target_vars = SharedVars(network.params)
        args.target_update_flags = SharedFlags(args.num_actor_learners)
    if args.alg_type in ['dqn-cts', 'a3c-cts', 'a3c-lstm-cts']:  # TODO check density_model_update_flags
        args.density_model_update_flags = SharedFlags(args.num_actor_learners)

    if args.alg_type in ['AE']:
        args.target_vars_lower = SharedVars(network_lower.params)
        args.target_vars_upper = SharedVars(network_upper.params)
        args.target_update_flags = SharedFlags(args.num_actor_learners)
        args.density_model_update_flags = SharedFlags(args.num_actor_learners)

    tf.reset_default_graph()
    args.barrier = Barrier(args.num_actor_learners)
    args.global_step = SharedCounter(0)
    # TODO: make a Visualizer shared between the processes.
    args.num_actions = num_actions

    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
    num_gpus = 0
    if cuda_visible_devices:
        # CUDA_VISIBLE_DEVICES is comma-separated (e.g. "0,1"). Splitting on
        # whitespace counted that as a single GPU.
        num_gpus = len([dev for dev in cuda_visible_devices.split(',')
                        if dev.strip()])

    # spin up processes and block
    actor_learners = []
    task_queue = Queue()
    experience_queue = Queue()
    seed = args.seed or np.random.randint(2**32)
    np.random.seed(seed)
    tf.set_random_seed(seed)
    # visualize == 2 means "only the last learner visualises".
    visualize = args.visualize
    for i in range(args.num_actor_learners):
        if (visualize == 2) and (i == args.num_actor_learners - 1):
            args.visualize = 1
        else:
            args.visualize = 0

        args.actor_id = i
        args.device = '/gpu:{}'.format(i % num_gpus) if num_gpus else '/cpu:0'

        args.random_seed = seed + i

        # only used by TRPO
        args.task_queue = task_queue
        args.experience_queue = experience_queue

        args.input_shape = input_shape
        actor_learners.append(Learner(args))
        actor_learners[-1].start()
        if i == 1:
            setup_kill_signal_handler(actor_learners[-1])

    try:
        for t in actor_learners:
            # NOTE(review): every learner overwrites the same "grpah" file, so
            # only the last learner's data survives. Also, Learner exposes
            # terminate(), so it is presumably a multiprocessing.Process; if
            # so, t.vis here is the parent's copy, not the worker's live data.
            with open("grpah", 'w') as graph_file:
                wr = csv.writer(graph_file, quoting=csv.QUOTE_ALL)
                wr.writerow(t.vis.plot_data['X'])
                wr.writerow(t.vis.plot_data['Y'])
                print ('[%s]' % ', '.join(map(str, t.vis.plot_data['X'])))
            t.join()
    except KeyboardInterrupt:
        # Terminate with extreme prejudice
        for t in actor_learners:
            t.terminate()

    logger.info('All training threads finished!')
    logger.info('Use seed={} to reproduce'.format(seed))
Example #8
0
    def __init__(self, args):
        """Build the local network, the PGQ Q-learning ops and batch optimiser state.

        The Q-learning objective is built on ``self.local_network``, which is
        fed a batch whose first half holds current states and whose second
        half holds successor states. Its gradients are clipped according to
        ``clip_norm_type``.

        Optimiser state for the batch (Q) updates is allocated in
        ``self.batch_opt_st``: a local numpy buffer when
        ``optimizer_mode == "local"``, or ``args.batch_opt_state`` when
        ``"shared"``.
        """
        super(BasePGQLearner, self).__init__(args)

        # args.entropy_regularisation_strength = 0.0
        conf_learning = {
            'name': 'local_learning_{}'.format(self.actor_id),
            'input_shape': self.input_shape,
            'num_act': self.num_actions,
            'args': args
        }

        self.local_network = PolicyValueNetwork(conf_learning)
        self.reset_hidden_state()

        if self.is_master():
            var_list = self.local_network.params
            self.saver = tf.train.Saver(var_list=var_list,
                                        max_to_keep=3,
                                        keep_checkpoint_every_n_hours=2)

        # pgq specific initialization
        self.batch_size = 32
        self.pgq_fraction = args.pgq_fraction
        self.replay_memory = ReplayMemory(args.replay_size)
        # Q~ = beta * (log pi + H(pi)) + V for every state in the batch.
        self.q_tilde = self.local_network.beta * (
            self.local_network.log_output_layer_pi +
            tf.expand_dims(self.local_network.output_layer_entropy,
                           1)) + self.local_network.output_layer_v

        # First half of the batch -> current states, second -> successors.
        self.Qi, self.Qi_plus_1 = tf.split(axis=0,
                                           num_or_size_splits=2,
                                           value=self.q_tilde)
        self.V, _ = tf.split(axis=0,
                             num_or_size_splits=2,
                             value=self.local_network.output_layer_v)
        self.log_pi, _ = tf.split(
            axis=0,
            num_or_size_splits=2,
            value=tf.expand_dims(self.local_network.log_output_selected_action,
                                 1))
        self.R = tf.placeholder('float32', [None], name='1-step_reward')

        # gamma * max_a Q~(s', a), zeroed at terminal transitions.
        self.terminal_indicator = tf.placeholder(tf.float32, [None],
                                                 name='terminal_indicator')
        self.max_TQ = self.gamma * tf.reduce_max(
            self.Qi_plus_1, 1) * (1 - self.terminal_indicator)
        self.Q_a = tf.reduce_sum(
            self.Qi * tf.split(axis=0,
                               num_or_size_splits=2,
                               value=self.local_network.selected_action_ph)[0],
            1)

        # The TD error is held constant and weights grad(V + log pi(a|s)).
        self.q_objective = -self.pgq_fraction * tf.reduce_mean(
            tf.stop_gradient(self.R + self.max_TQ - self.Q_a) *
            (self.V[:, 0] + self.log_pi[:, 0]))

        self.V_params = self.local_network.params
        self.q_gradients = tf.gradients(self.q_objective, self.V_params)

        if self.local_network.clip_norm_type == 'global':
            self.q_gradients = tf.clip_by_global_norm(
                self.q_gradients, self.local_network.clip_norm)[0]
        elif self.local_network.clip_norm_type == 'local':
            self.q_gradients = [
                tf.clip_by_norm(g, self.local_network.clip_norm)
                for g in self.q_gradients
            ]

        if self.optimizer_mode == "local":
            # `size` used to be an undefined name here (NameError). Use the
            # total number of scalar parameters in the local network; this
            # presumably matches the flat layout of the shared variables
            # (TODO confirm against SharedVars).
            size = int(sum(np.prod(p.get_shape().as_list())
                           for p in self.local_network.params))
            if self.optimizer_type == "rmsprop":
                self.batch_opt_st = np.ones(size, dtype=ctypes.c_float)
            else:
                self.batch_opt_st = np.zeros(size, dtype=ctypes.c_float)
        elif self.optimizer_mode == "shared":
            self.batch_opt_st = args.batch_opt_state
Example #9
0
class TRPOLearner(BaseA3CLearner):
    '''
	Implementation of Trust Region Policy Optimization + Generalized Advantage Estimation 
	as described in https://arxiv.org/pdf/1506.02438.pdf

	∂'π = F^-1 ∂π where F is the Fischer Information Matrix
	We can't compute F^-1 directly except for very small networks
	so we'll use either conjugate gradient descent to approximate F^-1 ∂π
	'''
    def __init__(self, args):
        """Create this actor's TRPO policy and value networks and their ops.

        Entropy regularisation is forced to zero on ``args`` before the base
        initialiser runs. The master actor builds on ``/gpu:0`` and keeps a
        ``Saver`` over both networks; other actors build on ``/cpu:0``.
        """
        args.entropy_regularisation_strength = 0.0
        super(TRPOLearner, self).__init__(args)

        # Solver constants.
        self.batch_size = 512
        self.max_cg_iters = 20

        # Hyper-parameters and shared handles, copied in their original order.
        copied = ('num_epochs', 'cg_damping', 'cg_subsample', 'max_kl',
                  'max_rollout', 'episodes_per_batch', 'baseline_vars',
                  'experience_queue', 'task_queue')
        for attr_name in copied:
            setattr(self, attr_name, getattr(args, attr_name))

        def network_conf(prefix):
            # Same config for both networks apart from the scope name.
            return {
                'name': '{}_{}'.format(prefix, self.actor_id),
                'input_shape': self.input_shape,
                'num_act': self.num_actions,
                'args': args
            }

        policy_conf = network_conf('policy_network')
        value_conf = network_conf('value_network')

        if self.is_master():
            self.device = '/gpu:0'
        else:
            self.device = '/cpu:0'

        with tf.device(self.device):
            # Separate policy and value networks, as in the paper, so value
            # fitting does not damage the trust-region policy updates.
            self.policy_network = PolicyValueNetwork(policy_conf,
                                                     use_value_head=False)
            self.value_network = PolicyValueNetwork(value_conf,
                                                    use_policy_head=False)
            self.local_network = self.policy_network
            self._build_ops()

        if self.is_master():
            self.saver = tf.train.Saver(
                var_list=self.policy_network.params + self.value_network.params,
                max_to_keep=3,
                keep_checkpoint_every_n_hours=2)

    def _build_ops(self):
        """Build the TRPO surrogate loss, KL terms and conjugate-gradient ops.

        Creates the ``self.old_action_probs`` placeholder and ``self.theta``
        (the flattened policy parameters). ``self.policy_loss`` is the
        negative mean of advantage times the probability ratio new/old for
        the taken action, and ``self.pg`` is its flattened gradient. Also
        creates ``self.kl`` / ``self.kl_firstfixed``, the
        ``self.pg_placeholder`` feed, and ``self.fullstep`` /
        ``self.neggdotstepdir`` from ``_conjugate_gradient_ops``.
        """
        eps = 1e-10
        self.action_probs = self.policy_network.output_layer_pi
        self.old_action_probs = tf.placeholder(tf.float32,
                                               shape=[None, self.num_actions],
                                               name='old_action_probs')

        # Index of the taken action (argmax of selected_action_ph).
        action = tf.cast(
            tf.argmax(self.policy_network.selected_action_ph, axis=1),
            tf.int32)

        batch_idx = tf.range(0, tf.shape(action)[0])
        selected_prob = utils.ops.slice_2d(self.action_probs, batch_idx,
                                           action)
        old_selected_prob = utils.ops.slice_2d(self.old_action_probs,
                                               batch_idx, action)

        self.theta = utils.ops.flatten_vars(self.policy_network.params)
        self.policy_loss = -tf.reduce_mean(
            tf.multiply(self.policy_network.adv_actor_ph,
                        selected_prob / old_selected_prob))
        self.pg = utils.ops.flatten_vars(
            tf.gradients(self.policy_loss, self.policy_network.params))

        def discrete_kl_divergence():
            """Return (mean KL from old_action_probs, first-argument-fixed KL)."""
            kl = utils.stats.mean_kl_divergence_op(self.old_action_probs,
                                                   self.action_probs)
            # KL(stop_gradient(p) || p): its value is zero, but its Hessian
            # with respect to the parameters is the Fisher matrix. It is used
            # for Fisher-vector products in the CG solve.
            kl_firstfixed = tf.reduce_mean(
                tf.reduce_sum(tf.multiply(
                    tf.stop_gradient(self.action_probs),
                    tf.log(
                        tf.stop_gradient(self.action_probs + eps) /
                        (self.action_probs + eps))),
                              axis=1))
            return kl, kl_firstfixed

        def gaussian_kl_divergence():
            # Unused and unfinished. The previous body never built the mean
            # KL between the old and new Gaussian policies and returned the
            # undefined name `kl` (NameError). Fail explicitly instead.
            raise NotImplementedError(
                'KL divergence for Gaussian (continuous) policies is not '
                'implemented')

        self.kl, self.kl_firstfixed = discrete_kl_divergence()

        kl_grads = tf.gradients(self.kl_firstfixed, self.policy_network.params)
        flat_kl_grads = utils.ops.flatten_vars(kl_grads)

        # The policy gradient is computed outside the CG graph (in
        # minibatches) and fed back in through this placeholder.
        self.pg_placeholder = tf.placeholder(
            tf.float32,
            shape=self.pg.get_shape().as_list(),
            name='pg_placeholder')
        self.fullstep, self.neggdotstepdir = self._conjugate_gradient_ops(
            -self.pg_placeholder,
            flat_kl_grads,
            max_iterations=self.max_cg_iters)

    def _conjugate_gradient_ops(self,
                                pg_grads,
                                kl_grads,
                                max_iterations=20,
                                residual_tol=1e-10):
        '''
        Build an in-graph conjugate-gradient solve of (F + cg_damping * I) x = pg_grads.

        F is never formed explicitly. Each Fisher-vector product F p comes
        from differentiating sum(stop_gradient(p) * kl_grads) with respect to
        the policy parameters.

        Args:
            pg_grads: flat right-hand-side vector (_build_ops passes the
                negated policy-gradient placeholder).
            kl_grads: flat gradient of the KL term w.r.t. the policy params.
            max_iterations: upper bound on CG iterations.
            residual_tol: iteration stops once r.r is no longer above this.

        Returns:
            (fullstep, neggdotstepdir). fullstep is the CG solution rescaled
            so that 0.5 * s^T F s is approximately self.max_kl.
            neggdotstepdir is pg_grads . fullstep.
        '''
        i0 = tf.constant(0, dtype=tf.int32)
        # Continue while the residual is above tolerance and iterations remain.
        loop_condition = lambda i, r, p, x, rdotr: tf.logical_and(
            tf.greater(rdotr, residual_tol), tf.less(i, max_iterations))

        def body(i, r, p, x, rdotr):
            # Fisher-vector product F p via a second differentiation.
            fvp = utils.ops.flatten_vars(
                tf.gradients(tf.reduce_sum(tf.stop_gradient(p) * kl_grads),
                             self.policy_network.params))

            # Damped product (F + cg_damping * I) p.
            z = fvp + self.cg_damping * p

            # Standard CG update; 1e-8 guards against division by zero.
            alpha = rdotr / (tf.reduce_sum(p * z) + 1e-8)
            x += alpha * p
            r -= alpha * z

            new_rdotr = tf.reduce_sum(r * r)
            beta = new_rdotr / (rdotr + 1e-8)
            p = r + beta * p

            # Logs (iteration, residual) every time the body runs.
            new_rdotr = tf.Print(new_rdotr, [i, new_rdotr],
                                 'Iteration / Residual: ')

            return i + 1, r, p, x, new_rdotr

        # Start from x = 0, so the initial residual and direction are pg_grads.
        _, r, p, stepdir, rdotr = tf.while_loop(loop_condition,
                                                body,
                                                loop_vars=[
                                                    i0, pg_grads, pg_grads,
                                                    tf.zeros_like(pg_grads),
                                                    tf.reduce_sum(pg_grads *
                                                                  pg_grads)
                                                ])

        # Undamped F * stepdir, used to scale the step to the KL budget.
        fvp = utils.ops.flatten_vars(
            tf.gradients(tf.reduce_sum(tf.stop_gradient(stepdir) * kl_grads),
                         self.policy_network.params))

        # shs = 0.5 * s^T F s; dividing by lm scales this quadratic to ~max_kl.
        shs = 0.5 * tf.reduce_sum(stepdir * fvp)
        lm = tf.sqrt((shs + 1e-8) / self.max_kl)
        fullstep = stepdir / lm
        neggdotstepdir = tf.reduce_sum(pg_grads * stepdir) / lm

        return fullstep, neggdotstepdir

    def choose_next_action(self, state):
        """Return the policy network's action for ``state``.

        Delegates to ``self.policy_network.get_action`` using this learner's
        TensorFlow session and returns its result unchanged.
        """
        policy = self.policy_network
        return policy.get_action(self.session, state)

    def run_minibatches(self, data, *ops):
        """Evaluate ``ops`` over ``data`` in chunks of ``self.batch_size`` rows.

        Each chunk feeds slices of ``data['state']``, ``data['action']``,
        ``data['reward']`` (as the actor advantage) and ``data['pi']`` (as
        the old action probabilities). Results are combined as a
        chunk-size-weighted average over all rows.

        Returns:
            A list with one float32 array per op, in the order given.
        """
        totals = [np.zeros(op.get_shape().as_list(), dtype=np.float32)
                  for op in ops]

        n_rows = len(data['state'])
        for lo in range(0, n_rows, self.batch_size):
            hi = lo + np.minimum(self.batch_size, n_rows - lo)
            chunk_feed = {
                self.policy_network.input_ph: data['state'][lo:hi],
                self.policy_network.selected_action_ph: data['action'][lo:hi],
                self.policy_network.adv_actor_ph: data['reward'][lo:hi],
                self.old_action_probs: data['pi'][lo:hi],
            }
            chunk_results = self.session.run(ops, feed_dict=chunk_feed)
            # In-place += keeps each accumulator's float32 dtype and identity.
            for total, value in zip(totals, chunk_results):
                total += value * (hi - lo) / float(n_rows)

        return totals

    def linesearch(self, data, x, fullstep, expected_improve_rate):
        """Backtracking line search along ``fullstep`` starting from ``x``.

        Tries step fractions 0.7**k for k = 0..14. It accepts the first
        candidate whose mean KL is below ``self.max_kl`` and whose surrogate
        loss improves on the loss at ``x``.

        Args:
            data: batch dict consumed by ``run_minibatches``.
            x: current flat policy parameters.
            fullstep: maximal flat step (from the CG solve).
            expected_improve_rate: unused (the expected-improvement
                acceptance test is disabled); kept for interface
                compatibility.

        Returns:
            The accepted parameters, or ``x`` if no candidate was accepted.
            The policy network holds the returned parameters on exit.
        """
        backtrack_ratio = .7
        max_backtracks = 15

        # run_minibatches returns one output per op. Unpack the single loss
        # so that `improvement` below is a scalar, not a 1-element array.
        fval = self.run_minibatches(data, self.policy_loss)[0]

        for stepfrac in backtrack_ratio**np.arange(max_backtracks):
            xnew = x + stepfrac * fullstep
            self.assign_vars(self.policy_network, xnew)
            newfval, kl = self.run_minibatches(data, self.policy_loss, self.kl)

            improvement = fval - newfval
            logger.debug('Improvement {} / Mean KL {}'.format(improvement, kl))

            if kl < self.max_kl and improvement > 0:
                return xnew

        logger.debug('No update')
        # The last rejected trial step is still loaded; restore the start point.
        self.assign_vars(self.policy_network, x)
        return x

    def fit_baseline(self, data):
        """Accumulate value-network gradients over ``data`` and apply them.

        Rows are visited in a random permutation, in chunks of
        ``self.batch_size``, using ``data['reward']`` as the critic target.
        Per-chunk gradients are summed, each weighted by its share of the
        rows. The total is applied to ``self.baseline_vars`` via
        ``self._apply_gradients_to_shared_memory_vars``.
        """
        n_rows = len(data['state'])
        summed = [np.zeros(g.get_shape().as_list(), dtype=np.float32)
                  for g in self.value_network.get_gradients]

        # A random order keeps minibatches from introducing ordering bias.
        order = np.random.permutation(n_rows)
        for lo in range(0, n_rows, self.batch_size):
            hi = lo + np.minimum(self.batch_size, n_rows - lo)
            rows = order[lo:hi]
            chunk_grads = self.session.run(
                self.value_network.get_gradients,
                feed_dict={
                    self.value_network.input_ph: data['state'][rows],
                    self.value_network.critic_target_ph: data['reward'][rows],
                })
            for acc, grad in zip(summed, chunk_grads):
                acc += grad * (hi - lo) / float(n_rows)

        self._apply_gradients_to_shared_memory_vars(summed, self.baseline_vars)

    def predict_values(self, data):
        """Return the value network's estimates for ``data['state']`` as a 1-D array.

        All states are evaluated in a single session call (no minibatching), and
        column 0 of ``output_layer_v`` is returned.
        """
        vnet = self.value_network
        value_out = self.session.run(vnet.output_layer_v,
                                     feed_dict={vnet.input_ph: data['state']})
        return value_out[:, 0]

    def update_grads(self, data):
        """Run one TRPO update on a batch of experience.

        Steps:
        1. Fit the value baseline to the returns.
        2. Compute the policy gradient in minibatches.
        3. Compute the conjugate-gradient step on a random subsample.
        4. Backtracking line search.
        5. Assign the accepted parameters to the policy network.

        Args:
            data: dict of numpy arrays with keys 'state', 'action', 'pi',
                'reward' (discounted returns) and 'advantage'.  Modified in
                place: ``data['reward']`` is replaced by ``data['advantage']``.

        Returns:
            The value of ``self.kl`` evaluated on the CG subsample after the
            policy parameters have been updated.
        """
        # Fit the baseline on the returns *before* 'reward' is overwritten with
        # the advantages below; otherwise the value network would be regressed
        # onto the advantages instead of the returns.
        print('fitting baseline...')
        self.fit_baseline(data)

        # the actor-advantage feeds read data['reward'], so point it at the advantages
        data['reward'] = data['advantage']

        print('running policy gradient...')
        # computed in minibatches to avoid GPU OOM errors on Atari
        pg = self.run_minibatches(data, self.pg)[0]

        data_size = len(data['state'])
        # keep at least one sample so the CG ops never receive an empty batch
        subsample_size = max(1, int(data_size * self.cg_subsample))
        subsample = np.random.choice(data_size, subsample_size, replace=False)
        feed_dict = {
            self.policy_network.input_ph: data['state'][subsample],
            self.policy_network.selected_action_ph: data['action'][subsample],
            self.policy_network.adv_actor_ph: data['reward'][subsample],
            self.old_action_probs: data['pi'][subsample],
            self.pg_placeholder: pg
        }

        print('running conjugate gradient descent...')
        theta_prev, fullstep, neggdotstepdir = self.session.run(
            [self.theta, self.fullstep, self.neggdotstepdir],
            feed_dict=feed_dict)

        print('running linesearch...')
        new_theta = self.linesearch(data, theta_prev, fullstep, neggdotstepdir)
        self.assign_vars(self.policy_network, new_theta)

        return self.session.run(self.kl, feed_dict)

    def _run_worker(self):
        """Worker loop: roll out one episode per task until an 'EXIT' task arrives.

        For every other task taken from ``self.task_queue`` the worker:
        1. syncs its local policy network and value network from shared memory;
        2. plays one episode, truncated at ``self.max_rollout`` steps;
        3. computes discounted Monte-Carlo returns of the rescaled rewards;
        4. puts ``(data, episode_reward)`` on ``self.experience_queue``.
        """
        while True:
            task = self.task_queue.get()
            if task == 'EXIT':
                break

            self.sync_net_with_shared_memory(self.local_network,
                                             self.learning_vars)
            self.sync_net_with_shared_memory(self.value_network,
                                             self.baseline_vars)
            s = self.emulator.get_initial_state()

            data = {
                'state': list(),
                'pi': list(),
                'action': list(),
                'reward': list(),
            }
            episode_over = False
            accumulated_rewards = list()
            while not episode_over and len(
                    accumulated_rewards) < self.max_rollout:
                a, pi = self.choose_next_action(s)
                new_s, reward, episode_over = self.emulator.next(a)
                accumulated_rewards.append(self.rescale_reward(reward))

                data['state'].append(s)
                data['pi'].append(pi)
                data['action'].append(a)

                s = new_s

            # Discounted returns, accumulated back-to-front; append + reverse keeps
            # this linear in the episode length (inserting at the front is quadratic).
            mc_returns = list()
            running_total = 0.0
            for r in reversed(accumulated_rewards):
                running_total = r + self.gamma * running_total
                mc_returns.append(running_total)
            mc_returns.reverse()

            data['reward'].extend(mc_returns)
            # NOTE(review): this sums the *rescaled* rewards, so the reported
            # episode reward equals the raw score only if rescale_reward is identity.
            episode_reward = sum(accumulated_rewards)
            logger.debug('T{} / Episode Reward {}'.format(
                self.actor_id, episode_reward))

            self.experience_queue.put((data, episode_reward))

    def _run_master(self):
        """Master training loop, one iteration per epoch.

        Each epoch:
        1. Enqueue ``episodes_per_batch`` tasks for the workers.
        2. Collect their rollouts and attach GAE advantages computed from the
           current value predictions.
        3. Run one ``update_grads`` / ``update_shared_memory`` step.
        4. Log KL, mean episode reward and timings.
        """
        batch_keys = ('state', 'pi', 'action', 'reward', 'advantage')
        for epoch in range(self.num_epochs):
            batch = {key: list() for key in batch_keys}

            # hand out one task per episode; the payload is just the task index
            for task_idx in xrange(self.episodes_per_batch):
                self.task_queue.put(task_idx)

            episode_rewards = list()
            collect_start = time.time()
            for _ in xrange(self.episodes_per_batch):
                rollout, ep_reward = self.experience_queue.get()
                episode_rewards.append(ep_reward)

                state_values = self.predict_values(rollout)
                rollout['advantage'] = self.compute_gae(
                    rollout['reward'], state_values.tolist(), 0)
                for key in rollout:
                    batch[key].extend(rollout[key])

            train_start = time.time()
            batch_arrays = {key: np.array(vals) for key, vals in batch.items()}
            kl = self.update_grads(batch_arrays)
            self.update_shared_memory()
            train_end = time.time()

            logger.info(
                'Epoch {} / Mean KL Divergence {} / Mean Reward {} / Experience Time {:.2f}s / Training Time {:.2f}s'
                .format(epoch + 1, kl, np.array(episode_rewards).mean(),
                        train_start - collect_start, train_end - train_start))

    def train(self):
        """Process entry point.

        Non-master processes run the worker loop.  The master runs the epoch loop
        and then posts ``num_actor_learners`` 'EXIT' sentinels to the task queue
        so the worker loops stop.
        """
        if not self.is_master():
            self._run_worker()
            return

        self._run_master()
        for _ in xrange(self.num_actor_learners):
            self.task_queue.put('EXIT')