Exemplo n.º 1
0
    def build_policy_network(self):
        """Create ``self.policy_network`` and then fetch its policy info.

        When ``self.args.use_gnn_as_policy`` is truthy the policy is a
        ``gated_graph_policy_network.GGNN``; in that case
        ``self.raw_obs_placeholder`` is reset to ``None`` and
        ``self.node_info`` is taken from the network's ``get_node_info()``.
        Otherwise the policy is a ``policy_network.policy_network`` built
        with ``define_std=True``.

        Both variants are built immediately, are trainable, are not
        baselines, and live under the ``<name_scope>_policy`` scope.
        ``self.fetch_policy_info()`` is always called last.
        """
        use_gnn = self.args.use_gnn_as_policy
        scope = self.name_scope + '_policy'

        if not use_gnn:
            # Plain (non-graph) policy network.
            self.policy_network = policy_network.policy_network(
                session=self.session,
                name_scope=scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                ob_placeholder=None,
                trainable=True,
                build_network_now=True,
                define_std=True,
                is_baseline=False,
                args=self.args)
        else:
            # Gated graph neural network policy.
            self.policy_network = gated_graph_policy_network.GGNN(
                session=self.session,
                name_scope=scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                ob_placeholder=None,
                trainable=True,
                build_network_now=True,
                is_baseline=False,
                placeholder_list=None,
                args=self.args)
            # Only the graph branch resets the raw observation placeholder
            # and records the node information exposed by the network.
            self.raw_obs_placeholder = None
            self.node_info = self.policy_network.get_node_info()

        self.fetch_policy_info()
Exemplo n.º 2
0
    def build_policy_network(self, adj_matrix=None, node_attr=None):
        """Build the policy network selected by the flags on ``self.args``.

        Selection, in order:

        * ``use_gnn_as_policy`` falsy: ``policy_network.policy_network``.
        * ``use_gnn_as_policy`` without ``use_nervenet``:
          ``gated_graph_policy_network.GGNN``.
        * ``use_nervenet`` without ``nervenetplus``:
          ``nervenet_policy.nervenet``.
        * ``nervenetplus``: ``treenet_policy.nervenet`` when ``tree_net``
          is truthy, else ``nervenetplus_policy.nervenet``. This case also
          builds ``self.step_policy_network``, stores its variable list in
          ``self.step_policy_var_list`` and wraps it in
          ``utils.SetPolicyWeights`` as ``self.set_step_policy``.

        For every graph-based policy ``self.raw_obs_placeholder`` is set to
        ``None`` and ``self.node_info`` is read from the policy network.
        ``self.fetch_policy_info()`` is always called last.

        Args:
            adj_matrix: forwarded as ``adj_matrix`` to the NerveNet-family
                policy constructors; unused by the other branches.
            node_attr: forwarded as ``node_attr`` to the NerveNet-family
                policy constructors; unused by the other branches.

        Raises:
            AssertionError: if ``nervenetplus`` is set without both
                ``use_gnn_as_policy`` and ``use_nervenet``.
        """
        opts = self.args
        if opts.nervenetplus:
            # NerveNet+ is only valid on top of the NerveNet GNN policy.
            assert opts.use_gnn_as_policy and opts.use_nervenet

        graph_policy = opts.use_gnn_as_policy
        policy_scope = self.name_scope + '_policy'

        if not graph_policy:
            # Plain (non-graph) policy network.
            self.policy_network = policy_network.policy_network(
                session=self.session,
                name_scope=policy_scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                ob_placeholder=None,
                trainable=True,
                build_network_now=True,
                define_std=True,
                is_baseline=False,
                args=opts)
        elif not opts.use_nervenet:
            # Gated graph neural network policy.
            self.policy_network = gated_graph_policy_network.GGNN(
                session=self.session,
                name_scope=policy_scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                ob_placeholder=None,
                trainable=True,
                build_network_now=True,
                is_baseline=False,
                placeholder_list=None,
                args=opts)
        elif not opts.nervenetplus:
            # Basic NerveNet policy.
            self.policy_network = nervenet_policy.nervenet(
                session=self.session,
                name_scope=policy_scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                adj_matrix=adj_matrix,
                node_attr=node_attr,
                args=opts)
        else:
            # NerveNet+: the tree-net and regular variants are constructed
            # identically and differ only in the module providing them.
            backend = treenet_policy if opts.tree_net else nervenetplus_policy

            self.policy_network = backend.nervenet(
                session=self.session,
                name_scope=policy_scope,
                input_size=self.observation_size,
                output_size=self.action_size,
                adj_matrix=adj_matrix,
                node_attr=node_attr,
                args=opts,
                is_rollout_agent=self.is_rollout_agent)

            # NOTE(review): the step network reads the graph description
            # from self.adj_matrix / self.node_attr rather than from this
            # method's arguments, and its scope suffix is 'step_policy'
            # (no leading underscore). Both are kept exactly as they were;
            # confirm they are intentional before changing either.
            self.step_policy_network = backend.nervenet(
                session=self.session,
                name_scope=self.name_scope + 'step_policy',
                input_size=self.observation_size,
                output_size=self.action_size,
                adj_matrix=self.adj_matrix,
                node_attr=self.node_attr,
                args=opts,
                is_rollout_agent=True)
            self.node_info = self.policy_network.get_node_info()

            step_vars, _ = self.step_policy_network.get_var_list()
            self.step_policy_var_list = step_vars
            self.set_step_policy = utils.SetPolicyWeights(self.session,
                                                          step_vars)

        if graph_policy:
            # Shared post-processing for every graph-based policy. For
            # NerveNet+ this reads node_info a second time.
            self.raw_obs_placeholder = None
            self.node_info = self.policy_network.get_node_info()

        # NOTE: a previous, now disabled, variant built the NerveNet+ step
        # network under tf.variable_scope('', reuse=True) using the
        # '_policy' scope; the step network above is built under its own
        # scope instead.
        self.fetch_policy_info()