Example #1
import numpy as np
import tensorflow as tf

from relaax.common.algorithms import subgraph
from relaax.common.algorithms.lib import graph
from relaax.common.algorithms.lib import utils


class SharedParameters(subgraph.Subgraph):
    def build_graph(self):
        # Build graph
        sg_global_step = graph.GlobalStep()
        sg_initialize = graph.Initialize()

        # Expose public API
        self.op_n_step = self.Op(sg_global_step.n)
        self.op_initialize = self.Op(sg_initialize)


class Model(subgraph.Subgraph):
    def build_graph(self):
        # Build graph
        state = graph.Placeholder(np.float32, shape=(2,))
        reverse = graph.TfNode(tf.reverse(state.node, [0]))

        # Expose public API
        self.op_get_action = self.Op(reverse, state=state)


if __name__ == '__main__':
    utils.assemble_and_show_graphs(SharedParameters, Model)
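
For intuition, the Model subgraph above boils down to the following plain TensorFlow 1.x graph (assuming graph.Placeholder and graph.TfNode are thin wrappers over these primitives, which their usage here suggests):

import numpy as np
import tensorflow as tf

# Same computation as Model.build_graph: reverse a 2-element state vector.
state = tf.placeholder(np.float32, shape=(2,))
reverse = tf.reverse(state, [0])

with tf.Session() as sess:
    print(sess.run(reverse, feed_dict={state: [1.0, 2.0]}))  # -> [2. 1.]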
Example #2
        # Q-value ops for the online and the target network
        self.op_get_q_value = self.Op(sg_network.output.node,
                                      state=sg_network.ph_state)
        self.op_get_q_target_value = self.Op(
            sg_target_network.output.node,
            next_state=sg_target_network.ph_state)

        # Pick an action from the current Q-values and the local step
        self.op_get_action = self.Op(sg_get_action,
                                     local_step=sg_get_action.ph_local_step,
                                     q_value=sg_get_action.ph_q_value)

        sg_initialize = graph.Initialize()

        # Placeholders consumed by the loss when computing gradients
        feeds = dict(state=sg_network.ph_state,
                     reward=sg_loss.ph_reward,
                     action=sg_loss.ph_action,
                     terminal=sg_loss.ph_terminal,
                     q_next_target=sg_loss.ph_q_next_target,
                     q_next=sg_loss.ph_q_next)

        self.op_compute_gradients = self.Op(sg_gradients_calc.calculate,
                                            **feeds)

        # Refresh the target network from the online network's weights
        self.op_update_target_weights = self.Op(sg_update_target_weights)

        self.op_initialize = self.Op(sg_initialize)


if __name__ == '__main__':
    utils.assemble_and_show_graphs(GlobalServer, AgentModel)
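
The feeds above (reward, terminal, q_next_target) point at a DQN-style loss. As a rough sketch, the Bellman target such a loss typically bootstraps from looks like this in NumPy (gamma, the array shapes, and the max reduction are assumptions for illustration, not taken from the relaax source):

import numpy as np

def q_targets(reward, terminal, q_next_target, gamma=0.99):
    # Standard DQN backup: bootstrap from the target network's
    # next-state Q-values unless the transition was terminal.
    return reward + gamma * (1.0 - terminal) * q_next_target.max(axis=1)

reward = np.array([1.0, 0.0])
terminal = np.array([0.0, 1.0])
q_next_target = np.array([[0.5, 2.0], [1.0, 3.0]])
print(q_targets(reward, terminal, q_next_target))  # -> [2.98 0.  ]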
Example #3
        # Return the shared weights together with the current update step
        self.op_get_weights_signed = self.Ops(sg_weights, sg_update_step.n)

        # Apply incoming gradients and advance the update step
        self.op_apply_gradients = self.Ops(sg_gradients.apply, sg_update_step.increment,
                                           gradients=sg_gradients.ph_gradients,
                                           increment=sg_update_step.ph_increment)

        # Get/set all weights as a single flat vector
        self.op_get_weights_flatten = self.Op(sg_get_weights_flatten)
        self.op_set_weights_flatten = self.Op(sg_set_weights_flatten, value=sg_set_weights_flatten.ph_value)

        # Gradient combining routines
        self.op_submit_gradients = self.Call(graph.get_gradients_apply_routine(dppo_config.config))

        self.op_initialize = self.Op(sg_initialize)


# Weights of the policy are shared across
# all agents and stored on the parameter server
class SharedParameters(subgraph.Subgraph):
    def build_graph(self):
        sg_model = Model()

        sg_policy_shared = SharedWeights(sg_model.actor.weights)
        sg_value_func_shared = SharedWeights(sg_model.critic.weights)

        self.policy = sg_policy_shared
        self.value_func = sg_value_func_shared


if __name__ == '__main__':
    utils.assemble_and_show_graphs(SharedParameters, Model(assemble_model=True))
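
op_get_weights_flatten and op_set_weights_flatten above suggest that all weight tensors are serialized into one flat vector for transport between workers and the parameter server. A minimal NumPy sketch of that round trip (the helper names and the concatenation order are illustrative assumptions, not relaax's implementation):

import numpy as np

def flatten_weights(weights):
    # Concatenate every weight array into one flat vector,
    # remembering the shapes so it can be restored.
    shapes = [w.shape for w in weights]
    flat = np.concatenate([w.ravel() for w in weights])
    return flat, shapes

def unflatten_weights(flat, shapes):
    # Inverse of flatten_weights: split the vector and reshape.
    weights, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        weights.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return weights

weights = [np.ones((2, 3)), np.zeros(4)]
flat, shapes = flatten_weights(weights)
restored = unflatten_weights(flat, shapes)
assert all((a == b).all() for a, b in zip(weights, restored))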
Example #4
                                 ph_state=self.sg_network.ph_state)
        self.op_get_action_and_value = self.Ops(
            self.sg_network.pi,
            self.sg_network.vi,
            self.sg_network.lstm_state,
            ph_state=self.sg_network.ph_state,
            ph_goal=self.sg_network.ph_goal,
            ph_initial_lstm_state=self.sg_network.ph_initial_lstm_state,
            ph_step_size=self.sg_network.ph_step_size)
        self.op_get_action = self.Ops(  # use for exploitation
            self.sg_network.pi,
            self.sg_network.lstm_state,
            ph_state=self.sg_network.ph_state,
            ph_goal=self.sg_network.ph_goal,
            ph_initial_lstm_state=self.sg_network.ph_initial_lstm_state,
            ph_step_size=self.sg_network.ph_step_size)

        # with the LSTM state frozen
        self.op_get_value_zt = self.Ops(
            self.sg_network.perception,
            self.sg_network.vi,
            self.sg_network.lstm_state,
            ph_state=self.sg_network.ph_state,
            ph_initial_lstm_state=self.sg_network.ph_initial_lstm_state,
            ph_step_size=self.sg_network.ph_step_size)


if __name__ == '__main__':
    utils.assemble_and_show_graphs(GlobalManagerNetwork, LocalManagerNetwork,
                                   GlobalWorkerNetwork, LocalWorkerNetwork)
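
The ph_initial_lstm_state / ph_step_size placeholders above follow the usual recurrent-policy pattern: the caller feeds the LSTM state from the previous call back in on every step. A plain TensorFlow 1.x sketch of that plumbing (layer sizes and shapes are illustrative assumptions):

import numpy as np
import tensorflow as tf

state_size, features = 4, 8
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size)
ph_input = tf.placeholder(tf.float32, [1, None, features])  # [batch, time, features]
ph_c = tf.placeholder(tf.float32, [1, state_size])
ph_h = tf.placeholder(tf.float32, [1, state_size])
ph_step_size = tf.placeholder(tf.int32, [1])

outputs, lstm_state = tf.nn.dynamic_rnn(
    cell, ph_input, sequence_length=ph_step_size,
    initial_state=tf.nn.rnn_cell.LSTMStateTuple(ph_c, ph_h))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = h = np.zeros((1, state_size), np.float32)
    for _ in range(3):
        # Carry the recurrent state across successive calls.
        out, (c, h) = sess.run(
            [outputs, lstm_state],
            feed_dict={ph_input: np.ones((1, 1, features), np.float32),
                       ph_c: c, ph_h: h, ph_step_size: [1]})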