Example #1
    def build_model(self):
        # build the environments for the agent
        self.build_env()
        self.build_session()
        self.build_policy_network()
        self.fetch_policy_info()

        self.session.run(tf.global_variables_initializer())

        self.set_policy = utils.SetPolicyWeights(self.session,
                                                 self.policy_var_list)
        self.get_policy = utils.GetPolicyWeights(self.session,
                                                 self.policy_var_list)
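
The utils.GetPolicyWeights / utils.SetPolicyWeights helpers are project-specific and not shown above. Purely as a sketch, assuming they read and write policy_var_list as a list of numpy arrays through a TF1 session, they could look like this:

import tensorflow as tf


class GetPolicyWeights(object):
    """Fetch the current values of the policy variables as numpy arrays."""

    def __init__(self, session, var_list):
        self.session = session
        self.var_list = var_list

    def __call__(self):
        # running the variables directly returns their current values
        return self.session.run(self.var_list)


class SetPolicyWeights(object):
    """Assign externally provided numpy arrays back into the policy variables."""

    def __init__(self, session, var_list):
        self.session = session
        self.placeholders = [
            tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
            for var in var_list
        ]
        self.assign_ops = [
            var.assign(ph) for var, ph in zip(var_list, self.placeholders)
        ]

    def __call__(self, weights):
        # weights must be ordered the same way as var_list
        self.session.run(self.assign_ops,
                         feed_dict=dict(zip(self.placeholders, weights)))
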
Example #2
    def build_models(self):
        '''
            @brief:
                this is the function where the rollout agents and optimization
                agent build their networks, set up the placeholders, and gather
                the variable list.
        '''
        # make sure that the agent has a session
        self.build_session()

        # set the summary writer
        self.summary_writer = summary_handler.gym_summary_handler(
            self.session,
            self.get_experiment_name(),
            enable=self.args.write_summary,
            summary_dir=self.args.output_dir
        )

        # build the policy network and baseline network
        self.build_policy_network()

        # the baseline function to reduce the variance
        self.build_baseline_network()

        # the training op and graphs
        self.build_ppo_update_op()
        self.update_parameters = self.update_ppo_parameters

        # init the network parameters (xavier initializer)
        self.session.run(tf.global_variables_initializer())

        # the get policy weights op
        self.get_policy = utils.GetPolicyWeights(self.session,
                                                 self.policy_var_list)

        # prepare the feed_dict info for the ppo minibatches
        self.prepare_feed_dict_map()

        # prepare the initial kl penalty coefficient if needed
        self.current_kl_lambda = 1
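
current_kl_lambda is initialized to 1 here and then adapted inside update_ppo_parameters, which is not shown. The usual adaptive-KL-penalty schedule, sketched below purely as an illustration (not taken from this repository), grows the coefficient when the observed KL divergence overshoots the target and shrinks it when the update was too conservative:

def adapt_kl_lambda(kl_lambda, observed_kl, target_kl,
                    scale=2.0, tolerance=1.5,
                    min_lambda=1e-4, max_lambda=1e4):
    # penalize large policy steps more, small steps less
    if observed_kl > target_kl * tolerance:
        kl_lambda *= scale
    elif observed_kl < target_kl / tolerance:
        kl_lambda /= scale
    # keep the coefficient in a sane range
    return min(max(kl_lambda, min_lambda), max_lambda)
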
Example #3
    def build_models(self):
        '''
            @brief:
                this is the function where the rollout agents and optimization
                agent build their networks, set up the placeholders, and gather
                the variable list.
        '''
        # make sure that the agent has a session
        self.build_session()

        # the baseline function to reduce the variance
        self.build_baseline_network()

        # init the network parameters (xavier initializer)
        self.session.run(tf.global_variables_initializer())

        # cache the placeholders needed to build the ppo minibatch feed_dict
        self._receive_idx, self._send_idx, self._node_type_idx, \
            self._inverse_node_type_idx, self._batch_size_int = \
            self.baseline_network.get_gnn_idx_placeholder()
        self._num_nodes_ph = self.baseline_network._num_nodes_placeholder

        self._input_parameters = \
            self.baseline_network.get_input_parameters_placeholder()
        self._target_returns = \
            self.baseline_network.get_target_return_placeholder()
        self._vpred = self.baseline_network.get_vpred_placeholder()

        # prepare the initial kl penalty coefficient if needed
        self.current_kl_lambda = 1

        self.policy_var_list, self.all_policy_var_list = \
            self.baseline_network.get_var_list()

        self.get_policy = \
            utils.GetPolicyWeights(self.session, self.policy_var_list)
        self.set_policy = \
            utils.SetPolicyWeights(self.session, self.policy_var_list)
        self.initial_weights = self.get_policy()
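
The placeholders cached above are later bound to minibatch data when the feed_dict is assembled. A minimal sketch of that step, assuming each cached handle is a single placeholder and using illustrative batch keys rather than the repository's actual names:

    def make_baseline_feed_dict(self, batch):
        # batch keys below are illustrative only
        return {
            self._receive_idx: batch['receive_idx'],
            self._send_idx: batch['send_idx'],
            self._node_type_idx: batch['node_type_idx'],
            self._inverse_node_type_idx: batch['inverse_node_type_idx'],
            self._batch_size_int: batch['batch_size'],
            self._num_nodes_ph: batch['num_nodes'],
            self._input_parameters: batch['input_parameters'],
            self._target_returns: batch['target_returns'],
        }
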
Example #4
    def build_models(self, received_data=None):
        '''
            @brief:
                this is the function where the rollout agents and optimization
                agent build their networks, set up the placeholders, and gather
                the variable list.
        '''
        # record the start time and make sure that the agent has a session
        self.start_time = time.time()
        self.build_session()

        self.build_policy_network(adj_matrix=self.adj_matrix,
                                  node_attr=self.node_attr)

        self.baseline_network, self.target_return_placeholder, \
            self.raw_obs_placeholder = \
            agent_util.build_baseline_network(
                self.args, self.session, self.name_scope, self.observation_size,
                self.gnn_placeholder_list, self.obs_placeholder
            )

        # the training op and graphs
        self.build_ppo_update_op()
        self.update_parameters = self.update_ppo_parameters

        # init the network parameters (xavier initializer)
        self.session.run(tf.global_variables_initializer())

        # the get / set policy weight ops
        self.get_policy = \
            utils.GetPolicyWeights(self.session, self.policy_var_list)
        self.set_policy = \
            utils.SetPolicyWeights(self.session, self.policy_var_list)

        if received_data is not None:
            # set the running_mean and network weights here
            if received_data['reset']:
                # no inheriting from the parents!
                logger.info('Not Inheriting from parents!')
            else:
                if self.args.use_gnn_as_policy:
                    current_species_format = self.get_species_info()
                    processed_species_info = species2species.process_inherited_info(
                        raw_species_info=received_data,
                        current_species_format=current_species_format,
                        is_nervenet=self.args.use_gnn_as_policy)
                    self.set_policy(processed_species_info['policy_weights'])
                    self.set_running_means(
                        processed_species_info['running_mean_info'])
                    self.current_lr = received_data['lr']
                else:
                    # old species of the fc baselines
                    if received_data['SpcID'] > 0 and \
                            self.args.fc_amortized_fitness:
                        self.set_policy(received_data['policy_weights'])
                        self.set_running_means(
                            received_data['running_mean_info'])

        # prepare the feed_dict info for the ppo minibatches
        self.prepare_feed_dict_map()

        # prepare the initial kl penalty coefficient if needed
        self.current_kl_lambda = 1
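
For context, the optimization side would presumably hand a child agent its parent's weights through received_data roughly as below. Only the dictionary keys mirror the ones read in the example; parent_agent, child_agent and get_running_means are hypothetical names used for illustration:

# hypothetical driver-side usage
received_data = {
    'reset': False,                                   # inherit from the parent
    'SpcID': 3,                                       # species id of the child
    'lr': 3e-4,                                       # inherited learning rate
    'policy_weights': parent_agent.get_policy(),      # list of numpy arrays
    'running_mean_info': parent_agent.get_running_means(),  # hypothetical getter
}
child_agent.build_models(received_data=received_data)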