def create_tf_graph(self) -> None:
    """
    Builds the tensorflow graph needed for this policy.

    Idempotent: if the graph already contains variables, it is assumed the
    policy was already built and this returns without adding tensors.
    """
    with self.graph.as_default():
        tf.set_random_seed(self.seed)
        _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if len(_vars) > 0:
            # We assume the first thing created in the graph is the Policy. If
            # already populated, don't create more tensors.
            return
        self.create_input_placeholders()
        # Shared observation encoder feeding either actor head below.
        encoded = self._create_encoder(
            self.visual_in,
            self.processed_vector_in,
            self.h_size,
            self.num_layers,
            self.vis_encode_type,
        )
        if self.use_continuous_act:
            # Continuous-control (Gaussian) actor.
            self._create_cc_actor(
                encoded,
                self.tanh_squash,
                self.reparameterize,
                self.condition_sigma_on_obs,
            )
        else:
            # Discrete-control actor.
            self._create_dc_actor(encoded)
        self.trainable_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
        )
        self.trainable_variables += tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
        )  # LSTMs need to be root scope for Barracuda export
        # Tensors fetched on every inference call.
        self.inference_dict = {
            "action": self.output,
            "log_probs": self.all_log_probs,
            "entropy": self.entropy,
        }
        if self.use_continuous_act:
            self.inference_dict["pre_action"] = self.output_pre
        if self.use_recurrent:
            self.inference_dict["memory_out"] = self.memory_out
        # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
        # it will re-load the full graph
        self.initialize()
        # Create assignment ops for Ghost Trainer
        self.init_load_weights()
def init_load_weights(self):
    """
    Create one placeholder + assign op per global variable so that
    load_weights() can later feed externally supplied values into the
    graph (used by the Ghost Trainer to swap policy snapshots).

    Populates self.assign_phs and self.assign_ops in the same variable
    order that get_weights() returns values, which load_weights relies on.
    """
    with self.graph.as_default():
        _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        # Reset first: the previous version appended unconditionally, so a
        # second call would accumulate duplicate placeholders/ops and leak
        # graph nodes (callers had to clear the lists by hand).
        self.assign_phs = []
        self.assign_ops = []
        for var in _vars:
            # Use the variable's static shape rather than evaluating every
            # variable through the session just to read its value's shape.
            assign_ph = tf.placeholder(var.dtype, shape=var.get_shape())
            self.assign_phs.append(assign_ph)
            self.assign_ops.append(tf.assign(var, assign_ph))
def get_weights(self):
    """Return the current values of all global variables in this policy's graph."""
    with self.graph.as_default():
        global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        return [variable.eval(session=self.sess) for variable in global_vars]
def get_vars(self, scope):
    """Return the trainable variables registered under the given scope."""
    collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    return tf.get_collection(collection_key, scope=scope)
def start_learning(self, env_manager: EnvManager, inital_weights):
    """
    Run the training loop until training is done, optionally seeding the
    'Brain' policy with `inital_weights` on the first iteration.

    :param env_manager: manager used to reset and step the environment(s).
    :param inital_weights: optional list of arrays (in the order produced by
        the policy's get_weights()) loaded into the 'Brain' policy before
        the first step; pass None to train from scratch.
    :return: the final weights of the 'Brain' policy.
    :raises UnityCommunicationException, UnityEnvironmentException: re-raised
        after cleanup so the process exits with a non-zero return code.
    """
    self._create_output_path(self.output_path)
    global_step = 0
    last_brain_behavior_ids: Set[str] = set()
    try:
        # Initial reset
        self._reset_env(env_manager)
        first_step = True
        while self._not_done_training():
            external_brain_behavior_ids = set(env_manager.external_brains.keys())
            new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids
            self._create_trainers_and_managers(env_manager, new_behavior_ids)
            # Load inital weights on the very first pass only.
            if inital_weights is not None and first_step:
                print("Loading init weights!")
                policy = self.trainers['Brain'].get_policy(0)
                # Rebuild the assign ops from scratch. Clearing the lists
                # first keeps this correct even if init_load_weights only
                # appends; this replaces ~20 lines that inlined a copy of
                # init_load_weights' body.
                policy.assign_phs = []
                policy.assign_ops = []
                policy.init_load_weights()
                policy.load_weights(inital_weights)
                print("Inital weights loaded succesfully!")
            last_brain_behavior_ids = external_brain_behavior_ids
            n_steps = self.advance(env_manager)
            for _ in range(n_steps):
                global_step += 1
                self.reset_env_if_ready(env_manager, global_step)
            first_step = False
        # Stop advancing trainers, Killing trainers
        self.step = self.trainers['Brain'].step
        self.join_threads()
    except (
        KeyboardInterrupt,
        UnityCommunicationException,
        UnityEnvironmentException,
        UnityCommunicatorStoppedException,
    ) as ex:
        self.join_threads()
        self.logger.info(
            "Learning was interrupted. Please wait while the graph is generated."
        )
        if not isinstance(ex, (KeyboardInterrupt, UnityCommunicatorStoppedException)):
            # If the environment failed, we want to make sure to raise
            # the exception so we exit the process with a return code of 1.
            raise ex
    finally:
        if self.train_model:
            self._save_model()
            self._export_graph()
    # NOTE: the return must stay OUTSIDE the `finally` block — a return
    # inside `finally` would silently swallow the exception re-raised above
    # and the process would exit 0 on environment failure.
    return self.trainers['Brain'].get_policy(0).get_weights()