Example #1
    def _eval(self, observations, actions):
        feeds = {
            self._observations_ph: observations,
            self._actions_ph: actions
        }

        return tf_utils.get_default_session().run(self._output, feeds)
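
A minimal usage sketch for the `_eval` above (not part of the original example): it assumes the surrounding class built `self._observations_ph`, `self._actions_ph`, and `self._output` as TensorFlow placeholders/tensors, that a default session is active, and that `qf` is an instance of that class.

import numpy as np

# Hypothetical batch width values; they must match the placeholder shapes
# defined elsewhere in the class.
obs_dim, act_dim = 17, 6
obs_batch = np.zeros((32, obs_dim), dtype=np.float32)
act_batch = np.zeros((32, act_dim), dtype=np.float32)

# Feeds the two placeholders and evaluates self._output under the default session.
q_values = qf._eval(obs_batch, act_batch)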
Example #2
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(range(self._n_epochs + 1),
                                      save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(iteration=t +
                                          epoch * self._epoch_length,
                                          batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

            self.sampler.terminate()
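
The loop above touches `self.sampler` only through a handful of calls. A sketch of that implied interface, inferred purely from this example (the project's actual sampler classes carry more state and logic), might read:

class SamplerInterface(object):
    """Methods the `_train` loop above expects from `self.sampler` (sketch only)."""

    def initialize(self, env, policy, pool):
        """Bind the environment, the behavior policy, and the replay pool."""

    def sample(self):
        """Step the environment with the policy and store the transition in the pool."""

    def batch_ready(self):
        """Return True once the pool holds enough samples to train on."""

    def random_batch(self):
        """Return a random minibatch of stored transitions."""

    def log_diagnostics(self):
        """Record sampler statistics to the logger."""

    def terminate(self):
        """Clean up the environment once training ends."""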
Example #3
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()
Example #4
    def __init__(
            self,
            batch_size=64,
            n_epochs=1000,
            n_train_repeat=1,
            epoch_length=1000,
            min_pool_size=10000,
            max_path_length=1000,
            eval_n_episodes=10,
            eval_render=False,
            iter_callback=None
    ):
        """
        Args:
            batch_size (`int`): Size of the sample batch to be used
                for training.
            n_epochs (`int`): Number of epochs to run the training for.
            n_train_repeat (`int`): Number of times to repeat the training
                for a single time step.
            epoch_length (`int`): Epoch length.
            min_pool_size (`int`): Minimum size of the sample pool before
                running training.
            max_path_length (`int`): Number of timesteps before resetting
                environment and policy, and the number of paths used for
                evaluation rollout.
            eval_n_episodes (`int`): Number of rollouts to evaluate.
            eval_render (`bool`): Whether or not to render the evaluation
                environment.
            iter_callback (`Function(locals, globals)`): Callback function
                called before every epoch.
        """
        self._batch_size = batch_size
        self._n_epochs = n_epochs
        self._n_train_repeat = n_train_repeat
        self._epoch_length = epoch_length
        self._min_pool_size = min_pool_size
        self._max_path_length = max_path_length

        self._eval_n_episodes = eval_n_episodes
        self._eval_render = eval_render

        self._sess = tf_utils.get_default_session()

        self.env = None
        self.policy = None
        self.pool = None

        self.iter_callback = iter_callback
Example #5
    def __init__(self,
                 batch_size=64,
                 n_epochs=1000,
                 n_train_repeat=1,
                 epoch_length=1000,
                 min_pool_size=10000,
                 max_path_length=1000,
                 eval_n_episodes=10,
                 eval_render=False,
                 iter_callback=None):
        """
        Args:
            batch_size (`int`): Size of the sample batch to be used
                for training.
            n_epochs (`int`): Number of epochs to run the training for.
            n_train_repeat (`int`): Number of times to repeat the training
                for a single time step.
            epoch_length (`int`): Epoch length.
            min_pool_size (`int`): Minimum size of the sample pool before
                running training.
            max_path_length (`int`): Number of timesteps before resetting
                environment and policy, and the number of paths used for
                evaluation rollout.
            eval_n_episodes (`int`): Number of rollouts to evaluate.
            eval_render (`bool`): Whether or not to render the evaluation
                environment.
            iter_callback (`Function(locals, globals)`): Callback function
                called before every epoch.
        """
        self._batch_size = batch_size
        self._n_epochs = n_epochs
        self._n_train_repeat = n_train_repeat
        self._epoch_length = epoch_length
        self._min_pool_size = min_pool_size
        self._max_path_length = max_path_length

        self._eval_n_episodes = eval_n_episodes
        self._eval_render = eval_render

        self._sess = tf_utils.get_default_session()

        self.env = None
        self.policy = None
        self.pool = None

        self.iter_callback = iter_callback
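
For orientation, a hedged instantiation sketch of this constructor: the class name `RLAlgorithm` is assumed from `gt.rename_root('RLAlgorithm')` in the training loops above (in the project it is normally constructed through subclasses such as `SQL`, see Examples #8 and #9), and every keyword simply restates the documented default.

algorithm = RLAlgorithm(
    batch_size=64,
    n_epochs=1000,
    n_train_repeat=1,
    epoch_length=1000,
    min_pool_size=10000,
    max_path_length=1000,
    eval_n_episodes=10,
    eval_render=False,
    iter_callback=None,
)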
Example #6
 def __init__(
         self,
         batch_size=64,
         n_epochs=1000,
         n_train_repeat=1,
         epoch_length=1000,
         min_pool_size=10000,
         max_path_length=1000,
         eval_n_episodes=20,
         eval_render=True,
 ):
     """
     Args:
         batch_size (`int`): Size of the sample batch to be used
             for training.
         n_epochs (`int`): Number of epochs to run the training for.
         n_train_repeat (`int`): Number of times to repeat the training
              for a single time step.
         epoch_length (`int`): Epoch length.
         min_pool_size (`int`): Minimum size of the sample pool before
             running training.
         max_path_length (`int`): Number of timesteps before resetting
             environment and policy, and the number of paths used for
             evaluation rollout.
         eval_n_episodes (`int`): Number of rollouts to evaluate.
          eval_render (`bool`): Whether or not to render the evaluation
             environment.
     """
     self._batch_size = batch_size
     self._n_epochs = n_epochs
     self._n_train_repeat = n_train_repeat
     self._epoch_length = epoch_length
     self._min_pool_size = min_pool_size
     self._max_path_length = max_path_length
     self._eval_n_episodes = eval_n_episodes
     self._eval_render = eval_render
     self._sess = tf_utils.get_default_session()
     self.env = None
     self.policy = None
     self.pool = None
     self.reward = np.zeros([self._n_epochs, 4])
Example #7
 def _eval(self, inputs):
     feeds = dict(zip(self._inputs, inputs))
     return tf_utils.get_default_session().run(self._output, feeds)
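
Unlike Example #1, this `_eval` builds its feed dict positionally: callers pass values in the same order in which the placeholders were stored in `self._inputs`. A hedged one-call sketch (`fn`, `obs_batch`, and `act_batch` are hypothetical):

# Values are zipped onto self._inputs, so ordering matters.
outputs = fn._eval([obs_batch, act_batch])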
Example #8
    def __init__(self,
                 env,
                 policy,
                 q_function,
                 replay_buffer,
                 sampler,
                 discount=0.99,
                 epoch_length=1000,
                 eval_n_episodes=10,
                 eval_render=False,
                 kernel_fn=adaptive_isotropic_gaussian_kernel,
                 kernel_n_particles=16,
                 kernel_update_ratio=0.5,
                 n_epochs=1000,
                 n_train_repeat=1,
                 plotter=None,
                 policy_lr=1e-3,
                 q_function_lr=1e-3,
                 reward_scale=1,
                 save_full_state=False,
                 sess=None,
                 td_target_update_interval=1,
                 train_policy=True,
                 train_qf=True,
                 use_saved_policy=False,
                 use_saved_qf=False,
                 value_n_particles=16):
        super(SQL, self).__init__(**dict(sampler=sampler,
                                         epoch_length=epoch_length,
                                         eval_n_episodes=eval_n_episodes,
                                         eval_render=eval_render,
                                         n_epochs=n_epochs,
                                         n_train_repeat=n_train_repeat))
        self.env = env
        self.plotter = plotter
        self.policy = policy
        self.q_function = q_function
        self.replay_buffer = replay_buffer
        self._action_dim = self.env.action_space.n
        self._discount = discount
        self._kernel_fn = kernel_fn
        self._kernel_n_particles = kernel_n_particles
        self._kernel_update_ratio = kernel_update_ratio
        self._observation_dim = np.prod(self.env.observation_space.shape)
        self._policy_lr = policy_lr
        self._q_function_lr = q_function_lr
        self._qf_target_update_interval = td_target_update_interval
        self._reward_scale = reward_scale
        self._save_full_state = save_full_state
        self._sess = sess if sess else tf_utils.get_default_session()
        self._target_ops = []
        self._train_policy = train_policy
        self._train_qf = train_qf
        self._training_ops = []
        self._value_n_particles = value_n_particles

        self._create_placeholders()
        self._create_svgd_update()
        self._create_target_ops()
        self._create_td_update()

        self._sess.run(tf.global_variables_initializer())
        # self._sess.run(tf.local_variables_initializer())
        print(self._sess.run(tf.report_uninitialized_variables()))
        # self._sess.graph.finalize()

        if use_saved_qf:
            saved_qf_params = q_function.get_param_values()
            self.q_function.set_param_values(saved_qf_params)
        if use_saved_policy:
            saved_policy_params = policy.get_param_values()
            self.policy.set_param_values(saved_policy_params)
Example #9
    def __init__(
        self,
        base_kwargs,
        env,
        pool,
        qf,
        policy,
        plotter=None,
        policy_lr=1E-3,
        qf_lr=1E-3,
        value_n_particles=16,
        td_target_update_interval=1,
        kernel_fn=adaptive_isotropic_gaussian_kernel,
        kernel_n_particles=16,
        kernel_update_ratio=0.5,
        discount=0.99,
        reward_scale=1,
        use_saved_qf=False,
        use_saved_policy=False,
        save_full_state=False,
        train_qf=True,
        train_policy=True,
    ):
        """
        Args:
            base_kwargs (dict): Dictionary of base arguments that are directly
                passed to the base `RLAlgorithm` constructor.
            env (`rllab.Env`): rllab environment object.
            pool (`PoolBase`): Replay buffer to add gathered samples to.
            qf (`NNQFunction`): Q-function approximator.
            policy (`rllab.NNPolicy`): A policy function approximator.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing the Q-function during training.
            policy_lr (`float`): Learning rate used for the policy approximator.
            qf_lr (`float`): Learning rate used for the Q-function approximator.
            value_n_particles (`int`): The number of action samples used for
                estimating the value of the next state.
            td_target_update_interval (`int`): How often the target network is
                updated to match the current Q-function.
            kernel_fn (function object): A function object that represents
                a kernel function.
            kernel_n_particles (`int`): Total number of particles per state
                used in SVGD updates.
            kernel_update_ratio (`float`): The ratio of SVGD particles used for
                the computation of the inner/outer empirical expectation.
            discount (`float`): Discount factor.
            reward_scale (`float`): A factor that scales the raw rewards.
                Useful for adjusting the temperature of the optimal Boltzmann
                distribution.
            use_saved_qf (`bool`): If True, use the initial parameters provided
                in the Q-function instead of reinitializing.
            use_saved_policy (`bool`): If True, use the initial parameters provided
                in the policy instead of reinitializing.
            save_full_state (`bool`): If True, saves the full algorithm
                state, including the replay buffer.
            train_qf (`bool`): Whether to train the Q-function.
            train_policy (`bool`): Whether to train the policy.
        """
        super(SQL, self).__init__(**base_kwargs)

        self.env = env
        self.pool = pool
        self.qf = qf
        self.policy = policy
        self.plotter = plotter

        self._qf_lr = qf_lr
        self._policy_lr = policy_lr
        self._discount = discount
        self._reward_scale = reward_scale

        self._value_n_particles = value_n_particles
        self._qf_target_update_interval = td_target_update_interval

        self._kernel_fn = kernel_fn
        self._kernel_n_particles = kernel_n_particles
        self._kernel_update_ratio = kernel_update_ratio

        self._save_full_state = save_full_state
        self._train_qf = train_qf
        self._train_policy = train_policy

        self._observation_dim = flat_dim(self.env.observation_space)
        self._action_dim = flat_dim(self.env.action_space)

        self._create_placeholders()

        self._training_ops = []
        self._target_ops = []

        self._create_td_update()
        self._create_svgd_update()
        self._create_target_ops()

        if use_saved_qf:
            saved_qf_params = qf.get_param_values()
        if use_saved_policy:
            saved_policy_params = policy.get_param_values()

        self._sess = tf_utils.get_default_session()
        self._sess.run(tf.global_variables_initializer())

        if use_saved_qf:
            self.qf.set_param_values(saved_qf_params)
        if use_saved_policy:
            self.policy.set_param_values(saved_policy_params)
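
As the docstring notes, `base_kwargs` is forwarded unchanged to the base `RLAlgorithm` constructor, so its keys are the arguments shown in the earlier `__init__` examples. A hedged wiring sketch, assuming `env`, `pool`, `qf`, and `policy` have been constructed elsewhere:

base_kwargs = dict(
    n_epochs=1000,
    epoch_length=1000,
    n_train_repeat=1,
    batch_size=64,
    min_pool_size=10000,
    max_path_length=1000,
    eval_n_episodes=10,
    eval_render=False,
)
algorithm = SQL(
    base_kwargs=base_kwargs,
    env=env,
    pool=pool,
    qf=qf,
    policy=policy,
    qf_lr=1e-3,
    policy_lr=1e-3,
    value_n_particles=16,
)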
Example #10
    def _train_brownian(self, env, policies, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policies (`dict`): Policies used for training, indexed by agent name
            pool (`PoolBase`): Sample pool to add samples to
        """

        ## Initializing optimization
        self._init_training()

        # Initialize the sampler
        self.sampler.initialize(env, policies, pool)

        # For saving diagnostics
        self.episode_rewards = []
        self.episode_lengths = []

        # 1. Sample Nearby <-- brownian_agent.py
        # - Randomly pick one start_state from the current starts
        # - Starting from init_state, do a Brownian-motion rollout and add the visited points to starts
        # - hide rollout: sample starts up to params['start_pool_size'] entries
        logger.log('Re-sampling new start positions ...')
        policies['hide'].sample_nearby(animated=False)
        logger.log('%d new goals populated' % len(policies['hide'].starts))

        if env.spec.id[:6] == 'Blocks':
            start_scale = 2.4  #Works for maze1_singlegoal only, so be careful
            xlim = [-1.0, 1.0]
            ylim = [-1.0, 1.0]
        else:
            start_scale = 1.0
            xlim = [-0.22, 0.22]
            ylim = [-0.22, 0.22]

        variance_mean = policies['hide'].action_variance_default
        variance = policies['hide'].action_variance_default

        with tf_utils.get_default_session().as_default():
            total_episode_length = 0
            for itr in range(self._n_itr):
                logger.push_prefix('itr #%d | ' % itr)
                logger.log("Sample a path")  ## should be batch_size
                x_init_eplst = []
                y_init_eplst = []
                path_lengths = []
                start_state_taskclassif_labels = []
                for n in range(20):
                    for t in range(self.max_path_length):
                        # 2. Obtain Samples
                        # <-- sample_sql(sampler.py)
                        # Initialize by sampling init_state and goal from starts
                        # seek rollout: sample 20 paths
                        # - x_init_eplst, y_init_eplst hold each path's (normalized) start point
                        # - plot 'xy_time': t is the path length (time) => longer paths are redder
                        # - plot 'xy_tasklabels': r is whether the path length (time) exceeds timelen_max
                        start_pose, done, path_length = self.sampler.sample_sql(
                            animated=False)  ## add reward shaping and plotting
                        if done:
                            print("---------------------->done: path_length:",
                                  path_length)
                            x_init_eplst.append(start_pose[0][0] / 2.4)
                            y_init_eplst.append(start_pose[0][1] / 2.4)
                            path_lengths.append(path_length)
                            start_state_taskclassif_labels.append(
                                int(path_length < (self.timelen_max - 1)))
                            total_episode_length += path_length
                            break
                        if not self.sampler.batch_ready():
                            continue
                        else:
                            print("############################### ready")
                        for i in range(self._n_train_repeat):
                            print("---------------------->training")
                            self._do_training(
                                iteration=t + total_episode_length,
                                batch=self.sampler.random_batch())

                # x, y are the (normalized) start point of each path; t is the path length (time).
                # Long paths are plotted red, short paths blue.
                # print("+++++++++++++++++ plot ++++++++++++++++++++")
                # print('x:', x_init_eplst)
                # print('y:', y_init_eplst)
                # print('xy_time:', path_lengths)
                # print('xy_reward:', start_state_taskclassif_labels)
                # print("+++++++++++++++++++++++++++++++++++++++++++")
                self.myplotter.plot_xy_time(x=x_init_eplst,
                                            y=y_init_eplst,
                                            t=path_lengths,
                                            t_max=self.timelen_max,
                                            img_name='xy_time_itr' + str(itr),
                                            xlim=xlim,
                                            ylim=ylim)

                # x, y are the (normalized) start point of each path; r is whether the path length
                # (time) exceeds timelen_max: red (1) if path length < self.timelen_max, else blue (0).
                self.myplotter.plot_xy_reward(x=x_init_eplst,
                                              y=y_init_eplst,
                                              r=start_state_taskclassif_labels,
                                              img_name='xy_tasklabels_itr' +
                                              str(itr),
                                              name='xy_tasklabels',
                                              r_min=0.,
                                              r_max=1.,
                                              xlim=xlim,
                                              ylim=ylim)

                update_now = (itr % (self.starts_update_every_itr * 10) == 0)
                update_period = self.starts_update_every_itr

                # Saving training samples for variance regression (prediction)
                logger.record_tabular('hide_starts_update_period',
                                      update_period)
                logger.record_tabular('hide_starts_update_period_max',
                                      self.starts_update_every_itr)

                if update_now:
                    self.myplotter.plot_goal_rewards(
                        goals=policies['hide'].starts,
                        rewards=policies['hide'].rewards,
                        img_name='goal_rewards_itr%03d' % itr,
                        scale=start_scale,
                        clear=True,
                        env=env)

                    logger.log('Filtering start positions ...')
                    policies['hide'].select_starts(
                        success_rate=self.center_reached_ratio)
                    print("*************** hide_starts ****************")
                    print(policies['hide'].starts)
                    print("*******************************************")
                    logger.log('%d goals selected' %
                               len(policies['hide'].starts))
                    self.myplotter.plot_goals(goals=policies['hide'].starts,
                                              color=[0, 1, 0],
                                              clear=True,
                                              env=env)

                    logger.log('Re-sampling new start positions ...')

                    # Update Variance
                    if self.brown_adaptive_variance == 4:
                        variance_diff = self.brown_var_control_coeff * (
                            self.center_reached_ratio -
                            self.center_reached_ratio_max)
                        variance_diff = np.clip(variance_diff, -0.5, 0.5)
                        logger.log('brown: variance change %f' % variance_diff)
                        variance_mean += variance_diff
                        variance_mean = np.clip(variance_mean,
                                                a_min=self.brown_var_min,
                                                a_max=1.0)
                    else:
                        variance_mean = self.policies[
                            'hide'].action_variance_default  # using default variance provided in the config

                    variance = copy.deepcopy(variance_mean)
                    logger.log('Adaptive Variance | r_avg: %f' %
                               self.center_reached_ratio)
                    logger.log('Adaptive Variance | variance_mean: [%f, %f]' %
                               (variance[0], variance[1]))

                    # print('!!!!!!!!!!!!!!!+++ brown: variance mean: ', variance, 'dtype', type(variance_mean))
                    policies['hide'].sample_nearby(
                        itr=itr,
                        success_rate=self.center_reached_ratio,
                        variance=variance,
                        animated=False)
                    logger.log('Re-sampled %d new goals %d old goals' %
                               (len(policies['hide'].starts),
                                len(policies['hide'].starts_old)))
                    self.myplotter.plot_goals(goals=policies['hide'].starts,
                                              color=[1, 0, 0],
                                              scale=start_scale,
                                              env=env)
                    self.myplotter.plot_goals(
                        goals=policies['hide'].starts_old,
                        color=[0, 0, 1],
                        scale=start_scale,
                        img_name='goals',
                        env=env)

                # Test environment
                # Roll out self._eval_n_episodes evaluation episodes
                paths = self._evaluate(policies, self.env_test)
                x_init_eplst = [path['observations'][0][0] for path in paths]
                y_init_eplst = [path['observations'][0][1] for path in paths]
                path_lengths = [path['rewards'].size for path in paths]

                # Same plot as above, but using env_test:
                # x, y are the (normalized) start point of each path; t is the path length (time).
                # Long paths are plotted red, short paths blue.
                self.myplotter.plot_xy_time(x=x_init_eplst,
                                            y=y_init_eplst,
                                            t=path_lengths,
                                            t_max=self.timelen_max,
                                            img_name='xy_time_test_itr' +
                                            str(itr),
                                            name='xy_time_test',
                                            xlim=xlim,
                                            ylim=ylim)

                params = self.get_snapshot(itr)
                logger.save_itr_params(itr, params)

                # time_itrs = gt.get_times().stamps.itrs
                # time_eval = time_itrs['eval'][-1]
                # time_total = gt.get_times().total
                # time_train = time_itrs.get('train', [0])[-1]
                # time_sample = time_itrs.get('sample', [0])[-1]
                #
                # logger.record_tabular('time-train', time_train)
                # logger.record_tabular('time-eval', time_eval)
                # logger.record_tabular('time-sample', time_sample)
                # logger.record_tabular('time-total', time_total)
                logger.record_tabular('itr', itr)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

            self.sampler.terminate()
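
The `brown_adaptive_variance == 4` branch in the middle of the loop behaves like a small proportional controller on the hide policy's Brownian step variance. Restated as a standalone helper, purely as a sketch with hypothetical names:

import numpy as np

def update_brownian_variance(variance_mean, success_rate, target_rate,
                             control_coeff, var_min):
    """Mirror of the adaptive-variance update above (sketch only).

    The change is proportional to (success_rate - target_rate), clipped to
    [-0.5, 0.5], and the updated variance is kept inside [var_min, 1.0].
    """
    variance_diff = np.clip(control_coeff * (success_rate - target_rate),
                            -0.5, 0.5)
    return np.clip(variance_mean + variance_diff, a_min=var_min, a_max=1.0)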
Example #11
def load_buffer_and_qf(filename):
    with tf_utils.get_default_session().as_default():
        data = joblib.load(os.path.join(PROJECT_PATH, filename))

    return data['replay_buffer'], data['qf']
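
`load_buffer_and_qf` assumes the snapshot was written with `joblib` and contains the keys 'replay_buffer' and 'qf'. A hedged sketch of the matching save side, reusing the same `PROJECT_PATH` constant the example relies on (the function name is an assumption):

import os
import joblib

def save_buffer_and_qf(filename, replay_buffer, qf):
    # Counterpart sketch: dump exactly the keys load_buffer_and_qf expects.
    joblib.dump({'replay_buffer': replay_buffer, 'qf': qf},
                os.path.join(PROJECT_PATH, filename))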
Example #12
def load_buffer_and_qf(filename):
    with tf_utils.get_default_session().as_default():
        data = joblib.load(os.path.join(PROJECT_PATH, filename))

    return data['replay_buffer'], data['qf']
Example #13
    def _eval(self, *inputs):
        feeds = {pl: val for pl, val in zip(self._inputs, inputs)}

        return tf_utils.get_default_session().run(self._output, feeds)
Example #14
    def _train(self, env, policy, pool, load=None):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
            load (`str`): Optional path to a TensorFlow checkpoint to restore
                before training starts
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        # evaluation_env = deep_clone(env) if self._eval_n_episodes else None

        with tf_utils.get_default_session().as_default() as sess:
            if load is not None:
                saver = tf.train.Saver()
                saver.restore(sess, load)
                print('pre-trained model restored ...')
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(range(self._n_epochs + 1),
                                      save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(iteration=t +
                                          epoch * self._epoch_length,
                                          batch=self.sampler.random_batch())
                    gt.stamp('train')

                if epoch % 1 == 0 or epoch >= ENV_PARAMS['n_epochs'] - 20:
                    self._evaluate(policy, env)
                    print('@ epoch %d : ' % epoch)
                    # gt.stamp('eval')
                    #
                    # params = self.get_snapshot(epoch)
                    # logger.save_itr_params(epoch, params)
                    #
                    # time_itrs = gt.get_times().stamps.itrs
                    # time_eval = time_itrs['eval'][-1]
                    # time_total = gt.get_times().total
                    # time_train = time_itrs.get('train', [0])[-1]
                    # time_sample = time_itrs.get('sample', [0])[-1]
                    #
                    # logger.record_tabular('time-train', time_train)
                    # logger.record_tabular('time-eval', time_eval)
                    # logger.record_tabular('time-sample', time_sample)
                    # logger.record_tabular('time-total', time_total)
                    # logger.record_tabular('epoch', epoch)

                    self.sampler.log_diagnostics()

                    logger.dump_tabular(with_prefix=False)
                    logger.pop_prefix()

                    # env.reset()

                if (epoch > ENV_PARAMS['n_epochs'] * 0 and epoch % 5
                        == 0) or epoch >= ENV_PARAMS['n_epochs'] - 100:
                    saver = tf.train.Saver()
                    saver.save(sess,
                               save_path=save_path + '/model-' + str(epoch) +
                               '.ckpt')
                    print('Model saved ...')

            self.sampler.terminate()
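
Since checkpoints above are written as '<save_path>/model-<epoch>.ckpt', resuming is a matter of feeding one of those prefixes back through the `load` argument. A hedged sketch (the instance name, the epoch number, and the `save_path` value are assumptions):

# Restores the graph variables via tf.train.Saver().restore before training resumes.
algorithm._train(env, policy, pool, load=save_path + '/model-900.ckpt')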
Example #15
    def _eval(self, *inputs):
        feeds = {pl: val for pl, val in zip(self._inputs, inputs)}

        return tf_utils.get_default_session().run(self._output, feeds)
Example #16
 def get_actions(self, observations):
     feeds = {self._obs_pl: observations}
     actions = tf_utils.get_default_session().run(self._action, feeds)
     return actions