Example #1
def f_fwgym_get_env(env_id, used_states, instance_index, query_classes,
                    query_class, params):
    """ Function that instantiates the Gym Env 

    ---
    Params 

    env_id              : (str)     Env ID
    instance_index      : (int)     Instance Index used in case of multiple environments
    query_classes       : (TBD)     Query Classes 
    query_class         : (TBD)     Query Class
    params              : (Dict)    Training or Testing Params defined in the related config file in a structure, some of them have an effect on the env internal details
    """
    tp_desc = TrainingParamsDict(params)
    aero = params['aero']['enabled']
    rest_range = params['aero']['windgust']['rest_range']
    period_range = params['aero']['windgust']['h_range']
    magnitude_max = params['aero']['windgust']['magnitude_max']
    windgust_generator = WindgustGenerator(h_range=period_range,
                                           rest_range=rest_range,
                                           agmax=magnitude_max)
    continuous = tp_desc.get_is_continuous()
    saturation_motor_min = params['quadcopter']['saturation_motor']
    logging.debug(
        f"[f_fwgym_get_env] Instantiating EnvID={env_id}, continuous={continuous}"
    )
    if continuous:
        quadcopter = Quadcopter(
            T=tp_desc.qg_continuous.get_T_episode(),
            dt_commands=tp_desc.qg_continuous.get_dt_command(),
            dt=tp_desc.qg_continuous.get_dt(),
            saturation_motor_min=saturation_motor_min,
            aero=aero,
            windgust_generator=windgust_generator)
        #env.set_continuous(quadcopter=Quadcopter(T=tp_desc.qg_continuous.get_T_episode(), dt_commands=tp_desc.qg_continuous.get_dt_command(), dt=tp_desc.qg_continuous.get_dt()))
    else:
        quadcopter = Quadcopter(
            T=tp_desc.qg_episodic.get_T_episode(),
            dt_commands=tp_desc.qg_episodic.get_dt_command(),
            dt=tp_desc.qg_episodic.get_dt(),
            saturation_motor_min=saturation_motor_min,
            aero=aero,
            windgust_generator=windgust_generator)
        #env.set_episodic(quadcopter=Quadcopter(T=tp_desc.qg_episodic.get_T_episode(), dt_commands=tp_desc.qg_episodic.get_dt_command(), dt=tp_desc.qg_episodic.get_dt()))

    env = GymEnvBase.make(env_id=env_id,
                          instance_index=instance_index,
                          params=params,
                          quadcopter=quadcopter,
                          query_classes=query_classes,
                          query_class=query_class,
                          used_states=used_states)
    env.reset()
    return env
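
# Illustrative sketch (not taken from the project's config files) of the minimal `params`
# layout read by f_fwgym_get_env above; the key names come from the dictionary accesses in
# the function body, while all values are assumptions.
#
#   params = {
#       'aero': {
#           'enabled': True,
#           'windgust': {
#               'rest_range': [0.5, 2.0],    # assumed rest-time range between gusts
#               'h_range': [0.1, 0.5],       # assumed gust period range
#               'magnitude_max': 3.0,        # assumed maximum gust magnitude
#           },
#       },
#       'quadcopter': {
#           'saturation_motor': 0.05,        # assumed motor saturation minimum
#       },
#       # plus the TrainingParamsDict fields (episode length, dt, dt_commands, ...)
#   }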
    def test_env_set_z_velocity_angles_reset_function(self):
        """ Tests the possibility to set the Velocity Reset Function from Config Training
        """
        env_desc = EnvDict(env_dict=self.args.env)
        tp_desc = TrainingParamsDict(tp_dict=self.args.training_params)
        self.assertEqual(env_desc.get_env_id(), self.training_config['env']['value'])
        self.assertEqual(self.args.model, self.training_config['model'])
        self.assertEqual(self.args.n_steps, self.training_config['n_steps'])
        self.assertEqual(self.args.training_params, self.training_config['training_params'])

        for i in range(3):
            env_id = f'gym_quadcopter:quadcopter-v{i}'
            env = f_fwgym_get_env(
                env_id=env_id, used_states=['e_p', 'e_q', 'e_r'],
                instance_index=0, query_class='something',
                query_classes={}, params=self.args.training_params
            )
            self.assertEqual(env.params, self.args.training_params)

            env.reset()

            # Each reset_policy entry provides [min, max] bounds; the config key matches
            # the quadcopter attribute except for 'abs_z', which maps to quadcopter.z.
            checks = [
                ('abs_z', env.quadcopter.z),
                ('velocity_x', env.quadcopter.velocity_x),
                ('velocity_y', env.quadcopter.velocity_y),
                ('velocity_z', env.quadcopter.velocity_z),
                ('abs_roll', env.quadcopter.abs_roll),
                ('abs_pitch', env.quadcopter.abs_pitch),
                ('abs_yaw', env.quadcopter.abs_yaw),
                ('rate_roll', env.quadcopter.rate_roll),
                ('rate_pitch', env.quadcopter.rate_pitch),
                ('rate_yaw', env.quadcopter.rate_yaw),
            ]
            for key, value in checks:
                bounds = self.args.training_params['quadcopter']['reset_policy'][key]['params']
                val_min, val_max = float(bounds[0]), float(bounds[1])
                self.assertTrue(val_min <= value <= val_max)
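
    # Hedged sketch of the `reset_policy` section exercised by the test above; the entry
    # names ('abs_z', 'velocity_*', 'abs_*', 'rate_*') and the [min, max] 'params' pairs
    # mirror the accesses in the assertions, while the numeric values are only assumptions.
    #
    #   'reset_policy': {
    #       'abs_z':      {'params': [-5.0, -1.0]},
    #       'velocity_x': {'params': [-0.5, 0.5]},
    #       ...
    #       'rate_yaw':   {'params': [-0.1, 0.1]},
    #   }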
    def test_env(self):
        env_desc = EnvDict(env_dict=self.args.env)
        tp_desc = TrainingParamsDict(tp_dict=self.args.training_params)
        self.assertEqual(env_desc.get_env_id(), self.training_config['env']['value'])
        self.assertEqual(self.args.model, self.training_config['model'])
        self.assertEqual(self.args.n_steps, self.training_config['n_steps'])
        self.assertEqual(self.args.training_params, self.training_config['training_params'])

        env_id = f'gym_quadcopter:quadcopter-v{env_desc.get_env_id()}'
        env = f_fwgym_get_env(
            env_id=env_id, used_states=['e_p', 'e_q', 'e_r'],
            instance_index=0, query_class='something', query_classes={},
            params=self.args.training_params
        )
        self.assertEqual(env.params, self.args.training_params)
class Training:

    def __init__(self):
        self.best_mean_reward = 0
        self.n_steps = 0
        self.stats = {"rewards": []}
        self.i = 0

    def process_end_of_actor_activation(self):
        """ Applies runtime patches to the Stable Baselines source code in order to set the End of Actor Activation
        """
        supported_values = ["tanh", "cbv"]
        if self.args.activation_end_of_actor not in supported_values:
            raise RuntimeError(f"End of Actor Activation {self.args.activation_end_of_actor} not supported")
        if self.args.activation_end_of_actor == "cbv":
            apply_tanh_patch()

    def f_clw_set_interval(self, x):
        """ Sets the interval at which checkpoints are saved
        """
        logging.debug(f"Operation: SET, Key: self.interval, Value: {x}")
        self.interval = x

    def f_clr_get_interval(self):
        """ Gets the interval at which checkpoints are saved
        """
        return self.interval

    def f_clw_set_model(self, x):
        """ Sets the model used for training
        """
        logging.debug(f"Operation: SET, Key: self.model, Value: {x['model_name']}")
        self.model = x['model']
        self.model_name = x['model_name']

    def f_clr_get_model(self):
        """ Gets the model used for training
        """
        logging.debug(f"Operation: GET, Key: self.model, Value: {self.model_name}")
        return self.model

    def f_clr_get_feed_dict(self, model): 
        feed_dict = {model.actions: model.stats_sample['actions']}

        for placeholder in [model.action_train_ph, model.action_target, model.action_adapt_noise, model.action_noise_ph]:
            if placeholder is not None:
                feed_dict[placeholder] = model.stats_sample['actions']

        for placeholder in [model.obs_train, model.obs_target, model.obs_adapt_noise, model.obs_noise]:
            if placeholder is not None:
                feed_dict[placeholder] = model.stats_sample['obs']

        return feed_dict


    def f_cb_check_switch(self):
        """ Switches the query generation mode (Continuous/Episodic) once the configured
        fraction of the total training steps has elapsed.
        """
        switch_due = (self.sp_desc.get_is_switch_active()
                      and not self.has_switched_training_mode
                      and (self.n_steps / self.args.n_steps) > self.sp_desc.get_time_perc())
        if switch_due:
            if self.sp_desc.get_is_continuous():
                temp = "Continuous"
                for x in self.__envs_training:
                    x.set_continuous(quadcopter=Quadcopter(T=self.tp_desc.qg_continuous.get_T_episode(), dt_commands=self.tp_desc.qg_continuous.get_dt_command(), dt=self.tp_desc.qg_continuous.get_dt()))
            else:
                temp = "Episodic"
                for x in self.__envs_training:
                    x.set_episodic(quadcopter=Quadcopter(T=self.tp_desc.qg_episodic.get_T_episode(), dt_commands=self.tp_desc.qg_episodic.get_dt_command(), dt=self.tp_desc.qg_episodic.get_dt()))
            logging.info(f"QUERY GENERATION MODE SWITCH HAPPENED, now it is {temp}")
            self.has_switched_training_mode = True

    def callback(self, _locals, _globals):
        self._debug_callback(model=_locals['self'], sim_time=self.i)
        self._callback_tf_log()
        if (self.n_steps + 1) % self.f_clr_get_interval() == 0:
            self.f_cb_check_switch()
            self.i += 1
            full_checkpoint_id = int(self.model_desc.get_checkpoint_id())+int(self.i)
            logging.info(f"Checkpoint ID: Internal={self.i}, Full={full_checkpoint_id}, n_timesteps: {self.n_steps}")
            temp = self._save_model_stable_baselines(model=_locals['self'], cp_id=full_checkpoint_id)
            self._save_model_sherlock(temp)

            if self.train_saver is not None:
                self.train_saver.save(sess=self.model.sess, save_path=f"{self.args.log_dir_tensorboard}/cp", global_step=self.i)
            if self.args.save_as_tf:
                path_save_cp = os.path.join(self.args.log_dir_tensorboard, f"cp-{self.i}")
                print(f"Saving Tensorflow Checkpoint in {path_save_cp}")
                self._save_model(path_save_cp)

            evaluation = f_model_2_evaluation(model=_locals['self'], env=self.env_test)
            quadcopter = self.__envs_training[0].quadcopter
            temp_plot_fn = f_iofsw_eval_2_plot(
                evaluation=evaluation, checkpoint_id=full_checkpoint_id,
                iteration_time=0, plots_dir=self.args.plots_dir,
                saturated=quadcopter.saturated, not_saturated=quadcopter.not_saturated)
            self.stats['rewards'].append(evaluation['re'])

        self.n_steps += 1
        # Returning False would stop training early; return True to keep training
        return True
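
    # Note on the interval check in callback(): with f_clr_get_interval() == 1000 (an
    # assumed value), (n_steps + 1) % 1000 == 0 triggers the checkpoint/evaluation branch
    # at n_steps = 999, 1999, 2999, ... i.e. once every 1000 training steps.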

    def _debug_callback(self, model, sim_time):
        if self.args.debug_is_active:
            if self.args.debug_model_describe:
                print(self._describe_model())
            if self.args.debug_try_save_all_vars:
                tf_path = f"{self.args.models_dir}/tf_quadcopter-{self.i}-desc"
                if not os.path.exists(tf_path):
                    os.mkdir(tf_path)
                tf_full_path = os.path.join(tf_path, "debug_vars_all.json")
                # Collect the names of all the nodes in the default graph
                res = ""
                for v in tf.get_default_graph().as_graph_def().node:
                    res += f"{v.name}\n"
                print(f"Trying to save debug data in {tf_full_path}")
                with open(tf_full_path, "w") as f:
                    f.write(res)
            if self.args.debug_try_save_trainable_vars:
                tf_path = f"{self.args.models_dir}/tf_quadcopter-{self.i}-desc"
                if not os.path.exists(tf_path):
                    os.mkdir(tf_path)
                tf_full_path = os.path.join(tf_path, "debug_vars_trainable.json")
                # Collect the names of the trainable variables only
                res = ""
                for v in tf.trainable_variables():
                    res += f"{v.name}\n"
                print(f"Trying to save debug data in {tf_full_path}")
                with open(tf_full_path, "w") as f:
                    f.write(res)
            if self.args.debug_try_save_graph:
                tf_path = f"{self.args.models_dir}/tf_quadcopter-{self.i}-desc"
                if not os.path.exists(tf_path):
                    os.mkdir(tf_path)
                tf_full_path = os.path.join(tf_path, "debug_graph.json")
                graph = tf.get_default_graph().as_graph_def()
                json_graph = json_format.MessageToJson(graph)
                print(f"Trying to save debug data in {tf_full_path}")
                with open(tf_full_path, "w") as f:
                    f.write(json_graph)
            if self.args.debug_try_save_weights:
                tf_path = f"{self.args.models_dir}/tf_quadcopter-{self.i}-desc"
                if not os.path.exists(tf_path):
                    os.mkdir(tf_path)
                tf_full_path = os.path.join(tf_path, "debug_weights.json")
                weights = tf.trainable_variables()
                weights_vals = tf.get_default_session().run(weights)
                print(f"Trying to save debug data in {tf_full_path}")
                with open(tf_full_path, "w") as f:
                    f.write(str(weights_vals))

            if self.args.debug_show_tensors_active:
                ops = [getattr(model, e) for e in self.args.debug_show_tensors_list]
                values = model.sess.run(ops, feed_dict=f_fwtf_get_feed_dict(model))
                for v in values:
                    print(f"v.shape = {v.shape}\nv.value={v}\n\n")

    def _save_model(self, export_dir): 
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 
        builder.add_meta_graph_and_variables(self.model.sess, [tf.saved_model.tag_constants.TRAINING])
        builder.save()

    def _save_model_stable_baselines(self, model, cp_id):
        # Save the current model as a Stable Baselines checkpoint
        path = f"{self.args.models_dir}/quadcopter-{cp_id}{self.args.suffix}"
        logging.info(f"SAVING CURRENT MODEL, Model SAVED at {path}")
        model.save(path)
        return path + '.pkl'

    def _save_model_sherlock(self, filename):
        output_filename = filename + '.sherlock'
        params = get_stable_baseline_file_params(filename)
        print(f"Saving Sherlock Format File {output_filename}")
        with open(output_filename, 'w') as file_:
            file_.write(architectures.export.get_sherlock_format(model_desc=self.model_desc, params=params))

    def _describe_model(self):
        res = f"Model.Graph Type={type(self.model.graph)}\nContent={dir(self.model.graph)}\n\n\n"
        res += f"Analysing {len(tf.get_default_graph().as_graph_def().node)} nodes \n"
        res += f"Graph Def = {tf.get_default_graph().as_graph_def()}\n"
        res += "---------\n"
        for v in tf.get_default_graph().as_graph_def().node:
            res += f"{v.name}\n"
        res += "-----------\n"
        return res

    def _get_action_noise(self, noise_dict, n_actions): 
        if noise_dict['name'] == 'OrnsteinUhlenbeck': 
            return OrnsteinUhlenbeckActionNoise(mean=float(noise_dict['mu'])*np.ones(n_actions), sigma=float(noise_dict['sigma']) * np.ones(n_actions))
        else: 
            raise RuntimeError(f"Unrecognized Noise Model {noise_dict['name']}")


    def _args2str(self,a): 
        return f"step={a.step}\n" \
               f"env={a.env}\n" \
               f"verbose={str(a.verbose)}\n" \
               f"save_plots={str(a.save_plots)}\n" \
               f"suffix={a.suffix}\n" \
               f"model={json.dumps(a.model)}\n" \
               f"activation={a.activation}\n" \
               f"action_noise={json.dumps(a.action_noise)}\n" \
               f"n_steps={a.n_steps}\n" \
               f"model_dir={a.models_dir}\n" \
               f"plots_dir={a.plots_dir}\n"

    def _get_plot_rewards(self):
        fig = plt.figure("Rewards")
        plt.plot(self.stats["rewards"])
        fig.suptitle('Reward')
        plt.xlabel('time')
        plt.ylabel('reward')
        return plt

    def _write_graph_def_for_tb(self, graph_def, LOGDIR): 
        """ TODO: Remove 
        """
        train_writer = tf.summary.FileWriter(LOGDIR)
        train_writer.add_graph(graph_def)
        train_writer.flush()
        train_writer.close()


    @property
    def sb_tb_log_active(self):
        """ Returns whether native Stable Baselines logging is active
        """
        return self.args.logging['tensorflow']['stable_baselines_native']['active']

    @property
    def sb_tb_log_dir(self):
        """ Returns the Stable Baselines TensorBoard log dir
        """
        return self.args.log_dir_tensorboard if self.sb_tb_log_active else None


    def f_clr_instantiate_model(self, m):
        model_name = m.get_model_name()
        if m.get_actor_feature_extractor_type() == 'standard':
            pk = dict(act_fun=activations[self.args.activation])
        else:
            pk = dict(act_fun=activations[self.args.activation], layers=m.get_actor_feature_extractor_architecture())
        model_params = {
            'policy': MlpPolicy,
            'env': self.env,
            'verbose': int(self.args.verbose),
            'policy_kwargs': pk,
            'tensorboard_log': self.sb_tb_log_dir,
            'full_tensorboard_log': self.sb_tb_log_active
        }
        if m.get_actor_feature_extractor_name() != 'mlp':
            raise NotImplementedError(f"Exporting Policy Type {m.get_actor_feature_extractor_name()} is unsupported at the moment")
        if model_name == 'ddpg':
            algo = DDPG
            model_params['param_noise'] = self.param_noise
            model_params['action_noise'] = self.action_noise
            model_params['render_eval'] = True
            model_params['policy'] = ddpg_policies.MlpPolicy
        elif model_name == 'trpo':
            algo = TRPO
            model_params['policy'] = common.MlpPolicy
        elif model_name == 'ppo':
            algo = PPO2
            model_params['policy'] = common.MlpPolicy
        elif model_name == 'td3':
            algo = TD3
            model_params['policy'] = td3_MlpPolicy
        elif model_name == 'sac':
            algo = SAC
            model_params['policy'] = sac_MlpPolicy
        else:
            raise RuntimeError(f"Unrecognized Model Name {model_name}")
        model = algo(**model_params)
        # Tensorboard #
        tf.io.write_graph(model.graph, self.args.log_dir_tensorboard, "model.pbtxt")
        if self.train_writer is not None:
            self.train_writer.add_graph(model.graph)
            self.train_writer.flush()
        logging.info(f"Instantiated Model Name={model_name}, policy={type(model_params['policy'])}, pk={pk}")
        return {"model": model, "model_name": model_name.upper()}

    def f_clw_instantiate_envs(self):
        """ Instantiates both the Training and the Test Gym Env
        - They provide the same dynamical model and the same reward
        """
        temp = f'gym_quadcopter:quadcopter-v{self.env_desc.get_env_id()}'
        # TODO FIXME: Some models cannot handle multiple envs.
        N = self.env_desc.get_n_envs()
        if N < 1:
            raise RuntimeError(f"NumEnvs needs to be >= 1 but got NumEnvs={N}")
        logging.info(f"[SETUP] Creating {N} Training Environments - START")

        # Instantiating all the Envs and storing them into a private var
        self.__envs_training = [f_fwgym_get_env(
            env_id=temp, used_states=self.used_states, instance_index=i,
            query_classes=self.query_classes, query_class=self.query_class,
            params=self.args.training_params
        ) for i in range(N)]

        # Passing references to the previously created envs; each env is bound as a lambda
        # default argument so every lambda returns its own env rather than the last one
        self.env = DummyVecEnv([lambda env=env: env for env in self.__envs_training])
        logging.info(f"[SETUP] Creating {N} Training Environments - DONE")
        logging.info(f"[SETUP] Creating 1 Test Environment - START")
        self.env_test = f_fwgym_get_env(
            env_id=temp, used_states=self.used_states, instance_index=0,
            query_classes=self.query_classes, query_class=self.query_class,
            params=self.args.testing_params
        )
        logging.info(f"[SETUP] Creating 1 Test Environment - DONE")

    def f_clw_args_2_state(self, args): 
        """Initialize internal instance state 
        """
        self.model_desc = ModelDict(model_dict=self.args.model)
        self.env_desc = EnvDict(env_dict=self.args.env)
        self.tp_desc = TrainingParamsDict(tp_dict=self.args.training_params)
        self.sp_desc = SwitchParamsDict(self.tp_desc.get_switch_params()) 
        self.query_classes = self.args.query_classes
        self.query_class = self.args.query_class
        self.used_states = self.args.used_states
        self.train_writer = None
        self.param_noise = None

        self.f_clw_instantiate_envs()
        self.n_actions = self.env.action_space.shape[-1]

        self.action_noise = f_fwgym_get_action_noise(noise_dict=self.args.action_noise, n_actions=self.n_actions)
        self.has_switched_training_mode = False


    def f_fwtfw_init(self): 
        """Initialize TF Environment 
        """
        tfl.set_verbosity(tfl.ERROR)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


    def get_global_summary(self):
        return {
            "ModelName": self.model_desc.get_model_name(),
            "Continuous": str(self.tp_desc.get_is_continuous()),
            "Total_Training_Iterations": self.args.n_steps,
            "Iterations_Per_Checkpoint": self.args.iterations_checkpoint,
            "Env": {"ID": self.env_desc.get_env_id(), "Num_Envs": self.env_desc.get_n_envs()},
        }

    def _add_tf_logs(self):
        """Adds the additional Tensorflow Logs to the standard Stable Baselines ones
        """
        with self.model.graph.as_default():
            # Conditional Logging for Summary
            if self.args.logging['tensorflow']['summary']['active']:
                tf.summary.text('Env Summary', tf.convert_to_tensor(str(self.env)))
            
            # Conditional Logging for the Stable Baselines Tensors specified in the list
            if self.args.logging['tensorflow']['stable_baselines_tensors']['active']:
                for e in self.args.logging['tensorflow']['stable_baselines_tensors']['list']:
                    tf.summary.scalar(f"Custom_SB_Log/{e}", tf.reduce_mean(getattr(self.model, e)))

            # Conditional Logging for the Tensorflow Tensors specified in the list
            if self.args.logging['tensorflow']['tensorflow_tensors']['active']:
                for e in self.args.logging['tensorflow']['tensorflow_tensors']['list']:
                    tf.summary.histogram(f"Custom_TF_Log/{e}", tf.get_default_graph().get_tensor_by_name(e))

            # Conditional Logging for Quadcopter Framework Events
            if self.args.logging['tensorflow']['events']['active']:
                if 'on_step' in self.args.logging['tensorflow']['events']['list']:
                    tf.summary.text(f'EnvStep{self.n_steps}', tf.convert_to_tensor(self.env.env_method('get_on_step_log')))

            # Merge all of the added summaries 
            self.model.summary = tf.summary.merge_all()
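
    # Hedged sketch of the `logging['tensorflow']` configuration read by the properties and
    # by _add_tf_logs above; the keys reflect the accesses in this class, while the values
    # and the tensor names in the lists are assumptions.
    #
    #   'logging': {
    #       'tensorflow': {
    #           'stable_baselines_native': {'active': True},
    #           'summary': {'active': True},
    #           'stable_baselines_tensors': {'active': False, 'list': ['critic_loss']},
    #           'tensorflow_tensors': {'active': False, 'list': ['model/pi/dense/kernel:0']},
    #           'events': {'active': True, 'list': ['on_step']},
    #       },
    #   }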


    def _callback_tf_log(self):
        with self.model.graph.as_default():
            if self.args.logging['tensorflow']['events']['active']:
                if 'on_step' in self.args.logging['tensorflow']['events']['list']:
                    tf.summary.text('EnvStep', tf.convert_to_tensor(self.env.env_method('get_on_step_log')))
                    self.model.summary = tf.summary.merge_all()

    def run_training(self, args):
        """ Training Function
        """
        # Use standard log just for the initial setup
        # Set the log used during training
        self.args = args
        self.process_end_of_actor_activation()
        self.f_clw_args_2_state(args)
        logging.info(f"Train Arguments\n{self._args2str(self.args)}") 
        logging.info(f"Writing Tensorboard Log to {self.args.log_dir_tensorboard}")
        logging.info(f"Start training at {dt.now().strftime('%Y%m%d_%H%M')}")
        self.f_fwtfw_init()
        if self.model_desc.get_is_load():
            # TODO: Fix this part 
            path = self.model_desc.get_checkpoint_path()
            model_name = self.model_desc.get_model_name()
            logging.info(f"LOADING MODEL at {path}")
            if model_name == "ddpg":
                self.model = DDPG.load(path, self.env)
            elif model_name == "ppo":
                self.model = PPO2.load(path, self.env)
            elif model_name == "trpo":
                self.model = TRPO.load(path, self.env)
            elif model_name == "td3":
                self.model = TD3.load(path, self.env)
            elif model_name == 'sac':
                self.model = SAC.load(path, self.env)
        else:
            self.f_clw_set_model(self.f_clr_instantiate_model(m=self.model_desc))

        self.f_clw_set_interval(self.args.iterations_checkpoint)

        if self.args.save_tf_checkpoint:
            with self.model.graph.as_default():
                self.train_saver = tf.compat.v1.train.Saver()
        else:
            self.train_saver = None

        self.i = 0
        # The training callback hook is implemented in
        # https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/ddpg/ddpg.py#L807

        logging.info(f"GLOBAL SUMMARY: {self.get_global_summary()}")
        self._add_tf_logs()
        self.model.learn(total_timesteps=int(self.args.n_steps), callback=self.callback)
        logging.info(f"Training Finished after {self.n_steps} iterations saving {self.i} intermediate checkpoints")
        logging.info(f"Saving Final Model in Stable Baseline Checkpoint")
        temp=self._save_model_stable_baselines(model=self.model, cp_id="final")

        print(f"Exporting Actor from Final Model in Stable Baseline Checkpoint as Sherlock Format")
        self._save_model_sherlock(temp)

        if self.train_writer is not None: self.train_writer.close()

        plt = self._get_plot_rewards()
        now = dt.now() 
        plt.savefig(f"{self.args.plots_dir}/reward_{now.strftime('%Y%m%d_%H%M%S')}.png")
        return True
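

# Minimal usage sketch (hypothetical driver; the construction of the `args` namespace is
# project specific and not shown in this example):
#
#   args = parse_training_args()      # hypothetical helper returning the argparse namespace
#   trainer = Training()
#   trainer.run_training(args)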