Esempio n. 1
0
    def create_tf_graph(self) -> None:
        """
        Build this policy's TensorFlow graph, unless the graph already holds variables.

        Inside ``self.graph`` this seeds TensorFlow with ``self.seed``, creates the input
        placeholders, the observation encoder and then either the continuous or the
        discrete actor, and records the trainable variables of the "policy" and "lstm"
        scopes in ``self.trainable_variables``. Afterwards it fills
        ``self.inference_dict``, calls ``self.initialize()`` and creates the
        weight-assignment ops through ``self.init_load_weights()``.

        If the graph already contains global variables, the method returns right after
        seeding and none of the remaining steps run.
        """
        with self.graph.as_default():
            tf.set_random_seed(self.seed)
            # The Policy is assumed to be the first thing built in this graph, so any
            # existing global variable means it was already constructed: build nothing.
            if tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                return

            self.create_input_placeholders()
            encoded = self._create_encoder(
                self.visual_in,
                self.processed_vector_in,
                self.h_size,
                self.num_layers,
                self.vis_encode_type,
            )
            if not self.use_continuous_act:
                self._create_dc_actor(encoded)
            else:
                self._create_cc_actor(
                    encoded,
                    self.tanh_squash,
                    self.reparameterize,
                    self.condition_sigma_on_obs,
                )

            trainable_key = tf.GraphKeys.TRAINABLE_VARIABLES
            self.trainable_variables = tf.get_collection(trainable_key, scope="policy")
            # LSTMs need to be root scope for Barracuda export, so they are collected
            # from their own "lstm" scope rather than from under "policy".
            self.trainable_variables.extend(
                tf.get_collection(trainable_key, scope="lstm")
            )

        # Tensors fetched at inference time; optional outputs depend on the action
        # space and on whether the policy is recurrent.
        self.inference_dict = {
            "action": self.output,
            "log_probs": self.all_log_probs,
            "entropy": self.entropy,
        }
        if self.use_continuous_act:
            self.inference_dict.update(pre_action=self.output_pre)
        if self.use_recurrent:
            self.inference_dict.update(memory_out=self.memory_out)

        # Initialize so the Policy is usable out of the box; if an optimizer is
        # needed, it will re-load the full graph.
        self.initialize()
        # Assignment ops used by the Ghost Trainer to overwrite the weights.
        self.init_load_weights()
Esempio n. 2
0
 def init_load_weights(self):
     """
     (Re)build one placeholder and one ``tf.assign`` op per global variable.

     For every variable in the graph's GLOBAL_VARIABLES collection, a placeholder of
     the variable's dtype and current shape is stored in ``self.assign_phs``, and the
     matching ``tf.assign(var, placeholder)`` op is stored in ``self.assign_ops``, in
     collection order.

     Both lists are reset first, so calling this again (for example after more
     variables were added to the graph) replaces the ops. Without the reset, a stale
     set would remain whose placeholders are never fed.
     """
     with self.graph.as_default():
         _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         # Current values are read only to obtain each variable's concrete shape.
         values = [v.eval(session=self.sess) for v in _vars]
         self.assign_phs = []
         self.assign_ops = []
         for var, value in zip(_vars, values):
             assign_ph = tf.placeholder(var.dtype, shape=value.shape)
             self.assign_phs.append(assign_ph)
             self.assign_ops.append(tf.assign(var, assign_ph))
Esempio n. 3
0
 def get_weights(self):
     """
     Return the evaluated value of every global variable in this policy's graph.

     Variables are taken from the graph's GLOBAL_VARIABLES collection, evaluated in
     ``self.sess``, and returned as a list in collection order.
     """
     with self.graph.as_default():
         return [
             variable.eval(session=self.sess)
             for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         ]
Esempio n. 4
0
 def get_vars(self, scope):
     """
     Return the trainable variables selected by ``scope``.

     ``scope`` is passed as the filter to ``tf.get_collection`` on the
     TRAINABLE_VARIABLES collection of the current default graph.
     """
     collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
     return tf.get_collection(collection_key, scope=scope)
    def start_learning(self, env_manager: EnvManager, inital_weights):
        """
        Run the training loop and return the final weights of the 'Brain' policy.

        :param env_manager: Environment manager whose external brains drive trainer
            creation and which is advanced/reset during training.
        :param inital_weights: Weights passed to the 'Brain' policy's ``load_weights``
            on the first loop iteration, after the trainers exist. Presumably a list in
            the format returned by ``get_weights()`` (TODO confirm). ``None`` skips
            loading.
        :return: ``get_weights()`` of the 'Brain' policy once training has stopped.
        :raises UnityCommunicationException, UnityEnvironmentException: re-raised after
            the trainer threads are joined, so the process exits with an error.
            ``KeyboardInterrupt`` and ``UnityCommunicatorStoppedException`` end training
            normally.
        """
        self._create_output_path(self.output_path)
        global_step = 0
        last_brain_behavior_ids: Set[str] = set()
        try:
            # Initial reset
            self._reset_env(env_manager)
            first_step = True
            while self._not_done_training():
                external_brain_behavior_ids = set(env_manager.external_brains.keys())
                new_behavior_ids = external_brain_behavior_ids - last_brain_behavior_ids
                self._create_trainers_and_managers(env_manager, new_behavior_ids)
                # Initial weights can only be loaded once the trainers exist.
                if inital_weights is not None and first_step:
                    self._load_initial_weights(inital_weights)

                last_brain_behavior_ids = external_brain_behavior_ids
                n_steps = self.advance(env_manager)
                for _ in range(n_steps):
                    global_step += 1
                    self.reset_env_if_ready(env_manager, global_step)
                first_step = False

            self.step = self.trainers['Brain'].step
            self.join_threads()
        except (
            KeyboardInterrupt,
            UnityCommunicationException,
            UnityEnvironmentException,
            UnityCommunicatorStoppedException,
        ) as ex:
            self.join_threads()
            self.logger.info(
                "Learning was interrupted. Please wait while the graph is generated."
            )
            if not isinstance(ex, (KeyboardInterrupt, UnityCommunicatorStoppedException)):
                # If the environment failed, we want to make sure to raise
                # the exception so we exit the process with an return code of 1.
                raise
        finally:
            if self.train_model:
                self._save_model()
                self._export_graph()
        # Deliberately outside ``finally``: a ``return`` there would silently discard
        # the exception re-raised above (and any other in-flight exception).
        return self.trainers['Brain'].get_policy(0).get_weights()

    def _load_initial_weights(self, weights) -> None:
        """
        Rebuild the 'Brain' policy's assign ops and load ``weights`` into it.

        The ops are rebuilt from the policy graph's current global variables. The ops
        created when the policy was built may predate variables added later, e.g. by an
        optimizer (NOTE(review): assumption, confirm against the trainer).
        """
        policy = self.trainers['Brain'].get_policy(0)
        self.logger.info("Loading init weights!")
        # Reset explicitly: some versions of init_load_weights only append to these lists.
        policy.assign_phs = []
        policy.assign_ops = []
        policy.init_load_weights()
        policy.load_weights(weights)
        self.logger.info("Inital weights loaded succesfully!")