Example #1
class DQNModel(Model):
    default_config = DQNModelConfig

    def __init__(self, config, scope, network_builder=None):
        """
        Training logic for DQN.

        :param config: Configuration dict
        """
        super(DQNModel, self).__init__(config, scope)

        self.action_count = self.config.actions
        self.tau = self.config.tau
        self.gamma = self.config.gamma
        # self.batch_size = self.config.batch_size

        self.double_dqn = self.config.double_dqn

        self.clip_value = None
        if self.config.clip_gradients:
            self.clip_value = self.config.clip_value

        if self.config.deterministic_mode:
            self.random = global_seed()
        else:
            self.random = np.random.RandomState()

        self.target_network_update = []

        # output layer
        output_layer_config = [{"type": "linear", "num_outputs": self.config.actions, "trainable": True}]

        # Input placeholders
        self.state_shape = tuple(self.config.state_shape)
        self.state = tf.placeholder(tf.float32, (None, None) + self.state_shape, name="state")
        self.next_states = tf.placeholder(tf.float32, (None, None) + self.state_shape,
                                          name="next_states")
        self.terminals = tf.placeholder(tf.float32, (None, None), name='terminals')
        self.rewards = tf.placeholder(tf.float32, (None, None), name='rewards')

        if network_builder is None:
            network_builder = NeuralNetwork.layered_network(self.config.network_layers + output_layer_config)

        self.training_network = NeuralNetwork(network_builder, [self.state], episode_length=self.episode_length,
                                              scope=self.scope + 'training')
        self.target_network = NeuralNetwork(network_builder, [self.next_states], episode_length=self.episode_length,
                                            scope=self.scope + 'target')

        self.training_internal_states = self.training_network.internal_state_inits
        self.target_internal_states = self.target_network.internal_state_inits

        self.training_output = self.training_network.output
        self.target_output = self.target_network.output

        # Create training operations
        self.create_training_operations()

        self.init_op = tf.global_variables_initializer()

        self.saver = tf.train.Saver()
        self.writer = tf.summary.FileWriter('logs', graph=tf.get_default_graph())
        self.session.run(self.init_op)

    def get_action(self, state, episode=1):
        """
        Returns the predicted action for a given state.

        :param state: State tensor
        :param episode: Current episode
        :return: action number
        """
        epsilon = self.exploration(episode, self.total_states)

        if self.random.random_sample() < epsilon:
            action = self.random.randint(0, self.action_count)
        else:
            fetches = [self.dqn_action]
            fetches.extend(self.training_internal_states)
            fetches.extend(self.target_internal_states)

            feed_dict = {self.episode_length: [1], self.state: [(state,)]}

            feed_dict.update({internal_state: self.training_network.internal_state_inits[n] for n, internal_state in
                              enumerate(self.training_network.internal_state_inputs)})
            feed_dict.update({internal_state: self.target_network.internal_state_inits[n] for n, internal_state in
                              enumerate(self.target_network.internal_state_inputs)})

            fetched = self.session.run(fetches=fetches, feed_dict=feed_dict)
            # First element of output list is action
            action = fetched[0][0]

            # Update optional internal states, e.g. LSTM cells
            num_training_states = len(self.training_internal_states)
            self.training_internal_states = fetched[1:1 + num_training_states]
            self.target_internal_states = fetched[1 + num_training_states:]

        self.total_states += 1

        return action

    def update(self, batch):
        """
        Performs a single training step and updates the target network.

        :param batch: Mini batch to use for training
        :return: void
        """
        self.logger.debug('Updating DQN model..')

        # Compute estimated future value
        float_terminals = batch['terminals'].astype(float)
        y = self.get_target_values(batch['next_states'])

        q_targets = batch['rewards'] + (1. - float_terminals) \
                                       * self.gamma * y

        feed_dict = {
            self.episode_length: [len(batch['rewards'])],
            self.q_targets: q_targets,
            self.actions: [batch['actions']],
            self.state: [batch['states']]
        }

        fetches = [self.optimize_op, self.training_output]
        fetches.extend(self.training_network.internal_state_outputs)
        fetches.extend(self.target_network.internal_state_outputs)

        for n, internal_state in enumerate(self.training_network.internal_state_inputs):
            feed_dict[internal_state] = self.training_internal_states[n]

        for n, internal_state in enumerate(self.target_network.internal_state_inputs):
            feed_dict[internal_state] = self.target_internal_states[n]

        fetched = self.session.run(fetches, feed_dict)

        # Update internal state list, e.g. for LSTM cells
        num_training_states = len(self.training_internal_states)
        self.training_internal_states = fetched[2:2 + num_training_states]
        self.target_internal_states = fetched[2 + num_training_states:]

    def get_variables(self):
        return self.training_network.get_variables()

    def assign_variables(self, values):
        assign_variables_ops = [variable.assign(value) for variable, value in zip(self.get_variables(), values)]
        self.session.run(tf.group(assign_variables_ops))

    def get_gradients(self):
        return self.grads_and_vars

    def apply_gradients(self, grads_and_vars):
        apply_gradients_op = self.optimizer.apply_gradients(grads_and_vars)
        self.session.run(apply_gradients_op)

    def create_training_operations(self):
        """
        Create graph operations for loss computation and
        target network updates.

        """
        with tf.name_scope(self.scope):
            with tf.name_scope("predict"):
                self.dqn_action = tf.argmax(self.training_output, axis=2, name='dqn_action')

            with tf.name_scope("targets"):
                if self.double_dqn:
                    selector = tf.one_hot(self.dqn_action, self.action_count, name='selector')
                    self.target_values = tf.reduce_sum(tf.multiply(self.target_output, selector), axis=2,
                                                       name='target_values')
                else:
                    self.target_values = tf.reduce_max(self.target_output, axis=2,
                                                       name='target_values')

            with tf.name_scope("update"):
                # Self.q_targets gets fed the actual observed rewards and expected future rewards
                self.q_targets = tf.placeholder(tf.float32, (None, None), name='q_targets')

                # Self.actions gets fed the actual actions that have been taken
                self.actions = tf.placeholder(tf.int32, (None, None), name='actions')

                # One_hot tensor of the actions that have been taken
                actions_one_hot = tf.one_hot(self.actions, self.action_count, 1.0, 0.0, name='action_one_hot')

                # Training output, so we get the expected rewards given the actual states and actions
                q_values_actions_taken = tf.reduce_sum(self.training_output * actions_one_hot, axis=2,
                                                       name='q_acted')

                # Surrogate loss as the mean squared error between actual observed rewards and expected rewards
                delta = self.q_targets - q_values_actions_taken

                # If gradient clipping is used, calculate the Huber loss
                if self.config.clip_gradients:
                    huber_loss = tf.where(tf.abs(delta) < self.clip_value, 0.5 * tf.square(delta), tf.abs(delta) - 0.5)
                    self.loss = tf.reduce_mean(huber_loss, name='compute_surrogate_loss')
                else:
                    self.loss = tf.reduce_mean(tf.square(delta), name='compute_surrogate_loss')

                self.grads_and_vars = self.optimizer.compute_gradients(self.loss)
                self.optimize_op = self.optimizer.apply_gradients(self.grads_and_vars)

                # Update target network with update weight tau
                with tf.name_scope("update_target"):
                    for v_source, v_target in zip(self.training_network.variables, self.target_network.variables):
                        update = v_target.assign_sub(self.tau * (v_target - v_source))
                        self.target_network_update.append(update)

    def get_target_values(self, next_states):
        """
        Estimate of next state Q values.
        :param next_states:
        :return:
        """
        if self.double_dqn:
            return self.session.run(self.target_values, {self.state: [next_states], self.next_states: [next_states]})
        else:
            return self.session.run(self.target_values, {self.next_states: [next_states]})

    def update_target_network(self):
        """
        Updates target network.

        :return:
        """
        self.session.run(self.target_network_update)
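
For reference, the Bellman target that update() feeds into self.q_targets can be reproduced outside the TensorFlow graph. The sketch below is a minimal numpy illustration (the function name and sample arrays are hypothetical, not part of the model above); it mirrors the line q_targets = rewards + (1 - terminals) * gamma * y.

import numpy as np

def compute_q_targets(rewards, terminals, target_values, gamma=0.99):
    # Bellman targets: r + (1 - done) * gamma * max_a' Q_target(s', a')
    return rewards + (1.0 - terminals.astype(np.float32)) * gamma * target_values

# Hypothetical mini batch of three transitions
rewards = np.array([1.0, 0.0, -1.0], dtype=np.float32)
terminals = np.array([False, False, True])
target_values = np.array([2.0, 1.5, 3.0], dtype=np.float32)

print(compute_q_targets(rewards, terminals, target_values))
# Terminal transitions contribute only their immediate reward.
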
Example #2
class DQNModel(Model):
    default_config = DQNModelConfig

    def __init__(self, config, scope, define_network=None):
        """
        Training logic for DQN.

        :param config: Configuration dict
        """
        super(DQNModel, self).__init__(config, scope)

        self.action_count = self.config.actions
        self.tau = self.config.tau
        self.gamma = self.config.gamma
        self.batch_size = self.config.batch_size

        self.double_dqn = self.config.double_dqn

        self.clip_value = None
        if self.config.clip_gradients:
            self.clip_value = self.config.clip_value

        if self.config.deterministic_mode:
            self.random = global_seed()
        else:
            self.random = np.random.RandomState()

        self.target_network_update = []

        # output layer
        output_layer_config = [{
            "type": "linear",
            "num_outputs": self.config.actions,
            "trainable": True
        }]

        self.device = self.config.tf_device
        if self.device == 'replica':
            self.device = tf.train.replica_device_setter(
                ps_tasks=1, worker_device=self.config.tf_worker_device)

        with tf.device(self.device):
            # Input placeholders
            self.state = tf.placeholder(tf.float32,
                                        self.batch_shape +
                                        list(self.config.state_shape),
                                        name="state")
            self.next_states = tf.placeholder(tf.float32,
                                              self.batch_shape +
                                              list(self.config.state_shape),
                                              name="next_states")
            self.terminals = tf.placeholder(tf.float32,
                                            self.batch_shape,
                                            name='terminals')
            self.rewards = tf.placeholder(tf.float32,
                                          self.batch_shape,
                                          name='rewards')

            if define_network is None:
                define_network = NeuralNetwork.layered_network(
                    self.config.network_layers + output_layer_config)

            self.training_model = NeuralNetwork(define_network, [self.state],
                                                scope=self.scope + 'training')
            self.target_model = NeuralNetwork(define_network,
                                              [self.next_states],
                                              scope=self.scope + 'target')

            self.training_output = self.training_model.get_output()
            self.target_output = self.target_model.get_output()

            # Create training operations
            self.create_training_operations()
            self.optimizer = tf.train.RMSPropOptimizer(self.alpha,
                                                       momentum=0.95,
                                                       epsilon=0.01)

        self.training_output = self.training_model.get_output()
        self.target_output = self.target_model.get_output()

        self.init_op = tf.global_variables_initializer()

        self.saver = tf.train.Saver()
        self.writer = tf.summary.FileWriter('logs',
                                            graph=tf.get_default_graph())

    def initialize(self):
        self.session.run(self.init_op)

    def get_action(self, state, episode=1):
        """
        Returns the predicted action for a given state.

        :param state: State tensor
        :param episode: Current episode
        :return: action number
        """

        epsilon = self.exploration(episode, self.total_states)

        if self.random.random_sample() < epsilon:
            action = self.random.randint(0, self.action_count)
        else:
            action = self.session.run(self.dqn_action,
                                      {self.state: [state]})[0]

        self.total_states += 1
        return action

    def update(self, batch):
        """
        Performs a single training step and updates the target network.

        :param batch: Mini batch to use for training
        :return: void
        """

        # Compute estimated future value
        float_terminals = batch['terminals'].astype(float)
        q_targets = batch['rewards'] + (1. - float_terminals) \
                                     * self.gamma * self.get_target_values(batch['next_states'])

        self.session.run(
            [self.optimize_op, self.training_output], {
                self.q_targets: q_targets,
                self.actions: batch['actions'],
                self.state: batch['states']
            })

    def get_variables(self):
        return self.training_model.get_variables()

    def assign_variables(self, values):
        assign_variables_ops = [
            variable.assign(value)
            for variable, value in zip(self.get_variables(), values)
        ]
        self.session.run(tf.group(assign_variables_ops))

    def get_gradients(self):
        return self.grads_and_vars

    def apply_gradients(self, grads_and_vars):
        apply_gradients_op = self.optimizer.apply_gradients(grads_and_vars)
        self.session.run(apply_gradients_op)

    def create_training_operations(self):
        """
        Create graph operations for loss computation and
        target network updates.

        :return:
        """
        with tf.name_scope(self.scope):
            with tf.name_scope("predict"):
                self.dqn_action = tf.argmax(self.training_output,
                                            axis=1,
                                            name='dqn_action')

            with tf.name_scope("targets"):
                if self.double_dqn:
                    selector = tf.one_hot(self.dqn_action,
                                          self.action_count,
                                          name='selector')
                    self.target_values = tf.reduce_sum(tf.multiply(
                        self.target_output, selector),
                                                       axis=1,
                                                       name='target_values')
                else:
                    self.target_values = tf.reduce_max(self.target_output,
                                                       axis=1,
                                                       name='target_values')

            with tf.name_scope("update"):
                # Self.q_targets gets fed the actual observed rewards and expected future rewards
                self.q_targets = tf.placeholder(tf.float32, [None],
                                                name='q_targets')

                # Self.actions gets fed the actual actions that have been taken
                self.actions = tf.placeholder(tf.int32, [None], name='actions')

                # One_hot tensor of the actions that have been taken
                actions_one_hot = tf.one_hot(self.actions,
                                             self.action_count,
                                             1.0,
                                             0.0,
                                             name='action_one_hot')

                # Training output, so we get the expected rewards given the actual states and actions
                q_values_actions_taken = tf.reduce_sum(self.training_output *
                                                       actions_one_hot,
                                                       axis=1,
                                                       name='q_acted')

                # Surrogate loss as the mean squared error between actual observed rewards and expected rewards
                delta = self.q_targets - q_values_actions_taken

                # If gradient clipping is used, calculate the Huber loss
                if self.config.clip_gradients:
                    huber_loss = tf.where(
                        tf.abs(delta) < self.clip_value,
                        0.5 * tf.square(delta),
                        tf.abs(delta) - 0.5)
                    self.loss = tf.reduce_mean(huber_loss,
                                               name='compute_surrogate_loss')
                else:
                    self.loss = tf.reduce_mean(tf.square(delta),
                                               name='compute_surrogate_loss')

                self.grads_and_vars = self.optimizer.compute_gradients(
                    self.loss)
                self.optimize_op = self.optimizer.apply_gradients(
                    self.grads_and_vars)

            # Update target network with update weight tau
            with tf.name_scope("update_target"):
                for v_source, v_target in zip(
                        self.training_model.get_variables(),
                        self.target_model.get_variables()):
                    update = v_target.assign_sub(self.tau *
                                                 (v_target - v_source))
                    self.target_network_update.append(update)

    def get_target_values(self, next_states):
        """
        Estimate of next state Q values.
        :param next_states:
        :return:
        """
        if self.double_dqn:
            return self.session.run(self.target_values, {
                self.state: next_states,
                self.next_states: next_states
            })
        else:
            return self.session.run(self.target_values,
                                    {self.next_states: next_states})

    def update_target_network(self):
        """
        Updates target network.

        :return:
        """
        self.session.run(self.target_network_update)
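
Both examples update the target network with the same soft rule, v_target <- v_target - tau * (v_target - v_source), i.e. an exponential moving average of the training weights controlled by tau. A minimal numpy sketch of that rule, assuming plain lists of weight arrays (names are illustrative only):

import numpy as np

def soft_update(target_weights, source_weights, tau=0.001):
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    return [t - tau * (t - s) for t, s in zip(target_weights, source_weights)]

target = [np.zeros((2, 2)), np.zeros(2)]
source = [np.ones((2, 2)), np.ones(2)]

for _ in range(100):
    target = soft_update(target, source, tau=0.05)
# target drifts toward source; with tau = 1.0 this becomes a hard copy.
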
Example #3
class NAFModel(Model):
    default_config = NAFModelConfig

    def __init__(self, config, scope, define_network=None):
        """
        Training logic for NAFs.

        :param config: Configuration parameters
        """
        super(NAFModel, self).__init__(config, scope)
        self.action_count = self.config.actions
        self.tau = self.config.tau
        self.epsilon = self.config.epsilon
        self.gamma = self.config.gamma
        self.batch_size = self.config.batch_size

        if self.config.deterministic_mode:
            self.random = global_seed()
        else:
            self.random = np.random.RandomState()

        self.state = tf.placeholder(tf.float32, self.batch_shape + list(self.config.state_shape), name="state")
        self.next_states = tf.placeholder(tf.float32, self.batch_shape + list(self.config.state_shape),
                                          name="next_states")

        self.actions = tf.placeholder(tf.float32, [None, self.action_count], name='actions')
        self.terminals = tf.placeholder(tf.float32, [None], name='terminals')
        self.rewards = tf.placeholder(tf.float32, [None], name='rewards')
        self.q_targets = tf.placeholder(tf.float32, [None], name='q_targets')
        self.target_network_update = []
        self.episode = 0

        # Get hidden layers from network generator, then add NAF outputs, same for target network
        scope = '' if self.config.tf_scope is None else self.config.tf_scope + '-'

        if define_network is None:
            define_network = NeuralNetwork.layered_network(self.config.network_layers)

        self.training_model = NeuralNetwork(define_network, [self.state], scope=scope + 'training')
        self.target_model = NeuralNetwork(define_network, [self.next_states], scope=scope + 'target')

        # Create output fields
        self.training_v, self.mu, self.advantage, self.q, self.training_output_vars = self.create_outputs(
            self.training_model.get_output(), 'outputs_training')
        self.target_v, _, _, _, self.target_output_vars = self.create_outputs(self.target_model.get_output(),
                                                                              'outputs_target')
        self.create_training_operations()
        self.saver = tf.train.Saver()
        self.session.run(tf.global_variables_initializer())

    def get_action(self, state, episode=1):
        """
        Returns NAF action(s) as given by the mean output of the network.

        :param state: Current state
        :param episode: Current episode
        :return: action
        """
        action = self.session.run(self.mu, {self.state: [state]})[0] + self.exploration(episode, self.total_states)
        self.total_states += 1

        return action

    def update(self, batch):
        """
        Executes a NAF update on a training batch.

        :param batch: Mini batch to use for training
        :return:
        """
        float_terminals = batch['terminals'].astype(float)

        q_targets = batch['rewards'] + (1. - float_terminals) * self.gamma * np.squeeze(
            self.get_target_value_estimate(batch['next_states']))

        self.session.run([self.optimize_op, self.loss, self.training_v, self.advantage, self.q], {
            self.q_targets: q_targets,
            self.actions: batch['actions'],
            self.state: batch['states']})

    def create_outputs(self, last_hidden_layer, scope):
        """
        Creates NAF specific outputs.

        :param last_hidden_layer: Points to last hidden layer
        :param scope: TF name scope

        :return: Output variables and all TF variables created in this scope
        """

        with tf.name_scope(scope):
            # State-value function
            v = linear(last_hidden_layer, {'num_outputs': 1, 'weights_regularizer': self.config.weights_regularizer,
                                           'weights_regularizer_args': [self.config.weights_regularizer_args]}, scope + 'v')

            # Action outputs
            mu = linear(last_hidden_layer, {'num_outputs': self.action_count, 'weights_regularizer': self.config.weights_regularizer,
                                            'weights_regularizer_args': [self.config.weights_regularizer_args]}, scope + 'mu')

            # Advantage computation
            # Network outputs entries of lower triangular matrix L
            lower_triangular_size = int(self.action_count * (self.action_count + 1) / 2)
            l_entries = linear(last_hidden_layer, {'num_outputs': lower_triangular_size,
                                                   'weights_regularizer': self.config.weights_regularizer,
                                                   'weights_regularizer_args': [self.config.weights_regularizer_args]},
                               scope + 'l')

            # Iteratively construct the lower triangular matrix L
            l_rows = []
            offset = 0

            for i in range(self.action_count):
                # Diagonal elements are exponentiated, otherwise gradient often 0
                # Slice out lower triangular entries from flat representation through moving offset

                diagonal = tf.exp(tf.slice(l_entries, (0, offset), (-1, 1)))

                n = self.action_count - i - 1
                # Slice out non-zero non-diagonal entries, - 1 because we already took the diagonal
                non_diagonal = tf.slice(l_entries, (0, offset + 1), (-1, n))

                # Fill up row with zeros
                row = tf.pad(tf.concat(axis=1, values=(diagonal, non_diagonal)), ((0, 0), (i, 0)))
                offset += (self.action_count - i)
                l_rows.append(row)

            # Stack rows to matrix
            l_matrix = tf.transpose(tf.stack(l_rows, axis=1), (0, 2, 1))

            # P = LL^T
            p_matrix = tf.matmul(l_matrix, tf.transpose(l_matrix, (0, 2, 1)))

            # Need to adjust dimensions to multiply with P.
            action_diff = tf.expand_dims(self.actions - mu, -1)

            # A = -0.5 * (a - mu)^T P (a - mu)
            advantage = -0.5 * tf.matmul(tf.transpose(action_diff, [0, 2, 1]),
                                         tf.matmul(p_matrix, action_diff))
            advantage = tf.reshape(advantage, [-1, 1])

            with tf.name_scope('q_values'):
                # Q = A + V
                q_value = v + advantage

        # Get all variables under this scope for target network update
        return v, mu, advantage, q_value, get_variables(scope)

    def create_training_operations(self):
        """
        NAF update logic.
        """

        with tf.name_scope("update"):
            # MSE
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_targets, tf.squeeze(self.q)),
                                       name='loss')
            self.optimize_op = self.optimizer.minimize(self.loss)

        with tf.name_scope("update_target"):
            # Combine hidden layer variables and output layer variables
            self.training_vars = self.training_model.get_variables() + self.training_output_vars
            self.target_vars = self.target_model.get_variables() + self.target_output_vars

            for v_source, v_target in zip(self.training_vars, self.target_vars):
                update = v_target.assign_sub(self.tau * (v_target - v_source))

                self.target_network_update.append(update)

    def get_target_value_estimate(self, next_states):
        """
        Estimate of next state V value through target network.

        :param next_states:
        :return:
        """

        return self.session.run(self.target_v, {self.next_states: next_states})

    def update_target_network(self):
        """
        Updates target network.

        :return:
        """
        self.session.run(self.target_network_update)
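
The advantage term built in create_outputs implements A(s, a) = -0.5 * (a - mu)^T P (a - mu) with P = L L^T, where L is a lower triangular matrix whose flat entries come from the network and whose diagonal is exponentiated. The numpy sketch below rebuilds the same quantity for a single state; the helper function and the sample values are hypothetical and only mirror the slicing logic above.

import numpy as np

def naf_advantage(l_entries, mu, action):
    # Rebuild L column by column from its n*(n+1)/2 flat entries,
    # exponentiating the diagonal, then evaluate the quadratic advantage.
    n = mu.shape[0]
    L = np.zeros((n, n))
    offset = 0
    for i in range(n):
        column = l_entries[offset:offset + n - i]
        L[i:, i] = column
        L[i, i] = np.exp(column[0])
        offset += n - i
    P = L @ L.T
    diff = (action - mu).reshape(-1, 1)
    return (-0.5 * diff.T @ P @ diff).item()

# Hypothetical 2-dimensional action space: n*(n+1)/2 = 3 flat entries
print(naf_advantage(np.array([0.1, 0.3, -0.2]),
                    mu=np.array([0.0, 0.5]),
                    action=np.array([0.2, 0.4])))
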
class DistributedPGModel(object):
    default_config = {}

    def __init__(self,
                 config,
                 scope,
                 task_index,
                 cluster_spec,
                 define_network=None):
        """

        A distributed agent must synchronise local and global parameters under different
        scopes.

        :param config: Configuration parameters
        :param scope: TensorFlow scope
        """

        self.session = None
        self.saver = None
        self.config = create_config(config, default=self.default_config)
        self.scope = scope
        self.task_index = task_index
        self.batch_size = self.config.batch_size
        self.action_count = self.config.actions
        self.use_gae = self.config.use_gae
        self.gae_lambda = self.config.gae_lambda

        self.gamma = self.config.gamma
        self.continuous = self.config.continuous
        self.normalize_advantage = self.config.normalise_advantage

        if self.config.deterministic_mode:
            self.random = global_seed()
        else:
            self.random = np.random.RandomState()

        if define_network is None:
            self.define_network = NeuralNetwork.layered_network(
                self.config.network_layers)
        else:
            self.define_network = define_network

        # This is the scope used to prefix variable creation for distributed TensorFlow
        self.batch_shape = [None]
        self.deterministic_mode = config.get('deterministic_mode', False)
        self.alpha = config.get('alpha', 0.001)
        self.optimizer = None

        self.worker_device = "/job:worker/task:{}/cpu:0".format(task_index)

        with tf.device(
                tf.train.replica_device_setter(
                    1, worker_device=self.worker_device,
                    cluster=cluster_spec)):
            with tf.variable_scope("global"):
                self.global_state = tf.placeholder(
                    tf.float32,
                    self.batch_shape + list(self.config.state_shape),
                    name="global_state")

                self.global_network = NeuralNetwork(self.define_network,
                                                    [self.global_state])
                self.global_step = tf.get_variable(
                    "global_step", [],
                    tf.int32,
                    initializer=tf.constant_initializer(0, dtype=tf.int32),
                    trainable=False)

                self.global_prev_action_means = tf.placeholder(
                    tf.float32, [None, self.action_count], name='prev_actions')

                if self.continuous:
                    self.global_policy = GaussianPolicy(
                        self.global_network, self.session, self.global_state,
                        self.random, self.action_count, 'gaussian_policy')
                    self.global_prev_action_log_stds = tf.placeholder(
                        tf.float32, [None, self.action_count])

                    self.global_prev_dist = dict(
                        policy_output=self.global_prev_action_means,
                        policy_log_std=self.global_prev_action_log_stds)

                else:
                    self.global_policy = CategoricalOneHotPolicy(
                        self.global_network, self.session, self.global_state,
                        self.random, self.action_count, 'categorical_policy')
                    self.global_prev_dist = dict(
                        policy_output=self.global_prev_action_means)

                # Linear value function used as a baseline for advantage estimation
                self.global_baseline_value_function = LinearValueFunction()

            # self.optimizer = config.get('optimizer')
            # self.optimizer_args = config.get('optimizer_args', [])
            # self.optimizer_kwargs = config.get('optimizer_kwargs', {})

        exploration = config.get('exploration')
        if not exploration:
            self.exploration = exploration_mode['constant'](self, 0)
        else:
            args = config.get('exploration_args', [])
            kwargs = config.get('exploration_kwargs', {})
            self.exploration = exploration_mode[exploration](self, *args,
                                                             **kwargs)

        self.create_training_operations()

    def set_session(self, session):
        self.session = session

        # Session in policy was still 'None' when
        # we initialised policy, hence need to set again
        self.policy.session = session

    def create_training_operations(self):
        """
        Currently a duplicate of the pg agent logic, to be made generic later to allow
        all models to be executed asynchronously/distributed seamlessly.

        """
        # TODO rewrite agent logic so core update logic can be composed into
        # TODO distributed logic

        with tf.device(self.worker_device):
            with tf.variable_scope("local"):
                self.state = tf.placeholder(tf.float32,
                                            self.batch_shape +
                                            list(self.config.state_shape),
                                            name="state")
                self.prev_action_means = tf.placeholder(
                    tf.float32, [None, self.action_count], name='prev_actions')

                self.local_network = NeuralNetwork(self.define_network,
                                                   [self.state])
                # TODO possibly problematic, check
                self.local_step = self.global_step

                if self.continuous:
                    self.policy = GaussianPolicy(self.local_network,
                                                 self.session, self.state,
                                                 self.random,
                                                 self.action_count,
                                                 'gaussian_policy')
                    self.prev_action_log_stds = tf.placeholder(
                        tf.float32, [None, self.action_count])

                    self.prev_dist = dict(
                        policy_output=self.prev_action_means,
                        policy_log_std=self.prev_action_log_stds)

                else:
                    self.policy = CategoricalOneHotPolicy(
                        self.local_network, self.session, self.state,
                        self.random, self.action_count, 'categorical_policy')
                    self.prev_dist = dict(policy_output=self.prev_action_means)

                # Linear value function used as a baseline for advantage estimation
                self.baseline_value_function = LinearValueFunction()

            self.actions = tf.placeholder(tf.float32,
                                          [None, self.action_count],
                                          name='actions')
            self.advantage = tf.placeholder(tf.float32,
                                            shape=[None, 1],
                                            name='advantage')

            self.dist = self.policy.get_distribution()
            self.log_probabilities = self.dist.log_prob(
                self.policy.get_policy_variables(), self.actions)

            # Get log likelihood of actions, weight by advantages, and compute the gradient on that
            self.loss = -tf.reduce_mean(
                self.log_probabilities * self.advantage, name="loss_op")

            self.gradients = tf.gradients(self.loss,
                                          self.local_network.get_variables())

            grad_var_list = list(
                zip(self.gradients, self.global_network.get_variables()))

            global_step_inc = self.global_step.assign_add(
                tf.shape(self.state)[0])

            self.assign_global_to_local = tf.group(*[
                v1.assign(v2)
                for v1, v2 in zip(self.local_network.get_variables(),
                                  self.global_network.get_variables())
            ])

            # TODO write summaries
            # self.summary_writer = tf.summary.FileWriter('log' + "_%d" % self.task_index)
            if not self.optimizer:
                self.optimizer = tf.train.AdamOptimizer(self.alpha)

            else:
                optimizer_cls = get_function(self.optimizer)
                self.optimizer = optimizer_cls(self.alpha,
                                               *self.optimizer_args,
                                               **self.optimizer_kwargs)

            self.optimize_op = tf.group(
                self.optimizer.apply_gradients(grad_var_list), global_step_inc)

    def get_action(self, state, episode=1):
        return self.policy.sample(state)

    def update(self, batch):
        """
        Get global parameters, compute update, then send results to parameter server.
        :param batch:
        :return:
        """

        self.compute_gae_advantage(batch, self.gamma, self.gae_lambda)

        # Update linear value function for baseline prediction
        self.baseline_value_function.fit(batch)

        # Merge episode inputs into single arrays
        _, _, actions, batch_advantage, states = self.merge_episodes(batch)

        self.session.run(
            [self.optimize_op, self.global_step], {
                self.state: states,
                self.actions: actions,
                self.advantage: batch_advantage
            })

    def get_global_step(self):
        """
        Returns global step to coordinator.
        :return:
        """
        return self.session.run(self.global_step)

    def sync_global_to_local(self):
        """
        Copy shared global weights to local network.

        """
        self.session.run(self.assign_global_to_local)

    def load_model(self, path):
        self.saver.restore(self.session, path)

    def save_model(self, path):
        self.saver.save(self.session, path)

    # TODO remove this duplication, move to util or let distributed agent
    # have a pg agent as a field
    def merge_episodes(self, batch):
        """
        Merge episodes of a batch into single input variables.

        :param batch:
        :return:
        """
        if self.continuous:
            action_log_stds = np.concatenate(
                [path['action_log_stds'] for path in batch])
            action_log_stds = np.expand_dims(action_log_stds, axis=1)
        else:
            action_log_stds = None

        action_means = np.concatenate([path['action_means'] for path in batch])
        actions = np.concatenate([path['actions'] for path in batch])
        batch_advantage = np.concatenate([path["advantage"] for path in batch])

        if self.normalize_advantage:
            batch_advantage = zero_mean_unit_variance(batch_advantage)

        batch_advantage = np.expand_dims(batch_advantage, axis=1)
        states = np.concatenate([path['states'] for path in batch])

        return action_log_stds, action_means, actions, batch_advantage, states

    # TODO duplicate code -> refactor from pg model
    def compute_gae_advantage(self, batch, gamma, gae_lambda, use_gae=False):
        """
        Expects a batch containing at least one episode, sets advantages according to use_gae.

        :param batch: Sequence of observations for at least one episode.
        """
        for episode in batch:
            baseline = self.baseline_value_function.predict(episode)
            if episode['terminated']:
                adjusted_baseline = np.append(baseline, [0])
            else:
                adjusted_baseline = np.append(baseline, baseline[-1])

            episode['returns'] = discount(episode['rewards'], gamma)

            if use_gae:
                deltas = episode['rewards'] + gamma * adjusted_baseline[
                    1:] - adjusted_baseline[:-1]
                episode['advantage'] = discount(deltas, gamma * gae_lambda)
            else:
                episode['advantage'] = episode['returns'] - baseline
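
compute_gae_advantage above relies on an external discount helper. As a self-contained illustration of the same recursion, delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) followed by a discounted sum with gamma * lambda, here is a minimal numpy sketch (discount and the sample episode are assumptions, not the library's own utilities):

import numpy as np

def discount(values, gamma):
    # Discounted cumulative sum: out[t] = sum_k gamma**k * values[t + k]
    out = np.zeros(len(values))
    running = 0.0
    for t in reversed(range(len(values))):
        running = values[t] + gamma * running
        out[t] = running
    return out

def gae_advantage(rewards, baseline, gamma=0.99, gae_lambda=0.97, terminated=True):
    # Append a bootstrap value: 0 for terminated episodes, last baseline value otherwise
    adjusted = np.append(baseline, 0.0 if terminated else baseline[-1])
    deltas = rewards + gamma * adjusted[1:] - adjusted[:-1]
    return discount(deltas, gamma * gae_lambda)

rewards = np.array([1.0, 0.0, 1.0])
baseline = np.array([0.5, 0.4, 0.6])
print(gae_advantage(rewards, baseline))
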