Example #1
    def _create_session_tf(self, task_parameters: TaskParameters):
        import tensorflow as tf
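        # os, copy_tree / remove_tree (distutils.dir_util) and DistributedTaskParameters are
        # module-level imports in rl_coach's graph_manager.py and are assumed to be in scope here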
        config = tf.ConfigProto()
        config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
        config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
        # config.gpu_options.per_process_gpu_memory_fraction = 0.2
        config.intra_op_parallelism_threads = 1
        config.inter_op_parallelism_threads = 1

        if isinstance(task_parameters, DistributedTaskParameters):
            # the distributed tensorflow setting
            from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
            if hasattr(self.task_parameters,
                       'checkpoint_restore_path') and self.task_parameters.checkpoint_restore_path:
                checkpoint_dir = os.path.join(task_parameters.experiment_path, 'checkpoint')
                if os.path.exists(checkpoint_dir):
                    remove_tree(checkpoint_dir)
                # in the locally distributed case, checkpoints are always restored from a directory (and not from a
                # file)
                copy_tree(task_parameters.checkpoint_restore_path, checkpoint_dir)
            else:
                checkpoint_dir = task_parameters.checkpoint_save_dir

            self.set_session(create_monitored_session(target=task_parameters.worker_target,
                                                      task_index=task_parameters.task_index,
                                                      checkpoint_dir=checkpoint_dir,
                                                      checkpoint_save_secs=task_parameters.checkpoint_save_secs,
                                                      config=config))
        else:
            # regular session
            self.set_session(tf.Session(config=config))

        # the TF graph is static, and therefore is saved once - in the beginning of the experiment
        if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
            self.save_graph()
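
For reference, the create_monitored_session helper used in the distributed branch is essentially a thin wrapper around TensorFlow 1.x's tf.train.MonitoredTrainingSession. The sketch below is an approximation under that assumption, not the verbatim rl_coach source; it assumes task 0 acts as the chief and that worker_target is the worker's gRPC target string.

import tensorflow as tf

def create_monitored_session(target, task_index, checkpoint_dir, checkpoint_save_secs, config=None):
    # the first task is conventionally the chief, responsible for initialization and checkpointing
    is_chief = task_index == 0
    # MonitoredTrainingSession initializes variables, restores/saves checkpoints and
    # coordinates the distributed workers; it exposes a run() interface compatible with tf.Session
    return tf.train.MonitoredTrainingSession(master=target,
                                             is_chief=is_chief,
                                             checkpoint_dir=checkpoint_dir,
                                             save_checkpoint_secs=checkpoint_save_secs,
                                             config=config)

Because the returned MonitoredSession behaves like a regular session for run() calls, it can be handed to set_session() unchanged.
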
Example #2
    def _create_session_tf(self, task_parameters: TaskParameters):
        import tensorflow as tf
        config = tf.ConfigProto()
        config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
        config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
        # config.gpu_options.per_process_gpu_memory_fraction = 0.2
        config.intra_op_parallelism_threads = 1
        config.inter_op_parallelism_threads = 1

        if isinstance(task_parameters, DistributedTaskParameters):
            # the distributed tensorflow setting
            from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
            if hasattr(self.task_parameters,
                       'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
                checkpoint_dir = os.path.join(task_parameters.experiment_path,
                                              'checkpoint')
                if os.path.exists(checkpoint_dir):
                    remove_tree(checkpoint_dir)
                copy_tree(task_parameters.checkpoint_restore_dir,
                          checkpoint_dir)
            else:
                checkpoint_dir = task_parameters.checkpoint_save_dir

            self.sess = create_monitored_session(
                target=task_parameters.worker_target,
                task_index=task_parameters.task_index,
                checkpoint_dir=checkpoint_dir,
                checkpoint_save_secs=task_parameters.checkpoint_save_secs,
                config=config)
            # set the session for all the modules
            self.set_session(self.sess)
        else:
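            # collect all graph variables so the Saver below can checkpoint and restore them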
            self.variables_to_restore = tf.global_variables()
            # self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
            self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)

            # regular session
            self.sess = tf.Session(config=config)

            # set the session for all the modules
            self.set_session(self.sess)

            # restore from checkpoint if given
            self.restore_checkpoint()

        # the TF graph is static, and therefore is saved once - in the beginning of the experiment
        if hasattr(self.task_parameters, 'checkpoint_save_dir') and self.task_parameters.checkpoint_save_dir:
            self.save_graph()
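
For context, neither version of _create_session_tf is called directly; it runs while a GraphManager builds its graph. Below is a minimal usage sketch, assuming rl_coach's standard preset and TaskParameters API; the preset, paths and parameter values are illustrative only, and the parameter names mirror the fields accessed in the snippets above.

from rl_coach.base_parameters import TaskParameters
from rl_coach.presets.CartPole_DQN import graph_manager  # illustrative preset; presets expose a graph_manager

task_parameters = TaskParameters(experiment_path='./experiments/cartpole',
                                 checkpoint_save_dir='./experiments/cartpole/checkpoint',
                                 checkpoint_save_secs=60)

# create_graph() builds the TF graph and, as part of that, creates the session
# via _create_session_tf() before training starts
graph_manager.create_graph(task_parameters)
graph_manager.improve()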