Example #1
    def sample(self):
        batch_data = []
        for _ in range(self.batch_size):
            processed_obs = self.preprocessor.process_obs(self.obs)
            judge_is_nan([processed_obs])
            action, logp = self.policy_with_value.compute_action(processed_obs[np.newaxis, :])
            if self.explore_sigma is not None:
                action += np.random.normal(0, self.explore_sigma, np.shape(action))
            try:
                judge_is_nan([action])
            except ValueError:
                print('processed_obs', processed_obs)
                print('preprocessor_params', self.preprocessor.get_params())
                print('policy_weights', self.policy_with_value.policy.trainable_weights)
                action, logp = self.policy_with_value.compute_action(processed_obs[np.newaxis, :])
                judge_is_nan([action])
                raise ValueError('action contains NaN')
            obs_tp1, reward, self.done, info = self.env.step(action.numpy()[0])
            # presumably updates the reward preprocessor's running statistics;
            # the return value is unused here and the raw reward is stored below
            processed_rew = self.preprocessor.process_rew(reward, self.done)
            batch_data.append((self.obs.copy(), action.numpy()[0], reward, obs_tp1.copy(), self.done, info['ref_index']))
            self.obs = self.env.reset() if self.done else obs_tp1.copy()
            # self.env.render()

        if self.worker_id == 1 and self.sample_times % self.args.worker_log_interval == 0:
            logger.info('Worker_info: {}'.format(self.get_stats()))

        self.num_sample += len(batch_data)
        self.sample_times += 1
        return batch_data
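
The judge_is_nan helper used throughout these methods is not shown in this example. Below is a minimal sketch consistent with how it is called (it raises ValueError when any entry contains NaN); the NumPy-based implementation is an assumption and the actual helper may differ.

import numpy as np

def judge_is_nan(list_of_arrays):
    # Raise ValueError if any entry (NumPy array, TF tensor or tf.Variable) contains NaN.
    for x in list_of_arrays:
        arr = x.numpy() if hasattr(x, 'numpy') else np.asarray(x)
        if np.any(np.isnan(arr)):
            raise ValueError('NaN detected: {}'.format(arr))
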
    def step(self):
        self.optimizer_stats.update(
            dict(update_queue_size=self.inqueue.qsize(),
                 update_time=self.update_timer.mean,
                 update_throughput=self.update_timer.mean_throughput,
                 grad_queue_get_time=self.grad_queue_get_timer.mean,
                 grad_apply_timer=self.grad_apply_timer.mean))
        # fetch grad
        with self.grad_queue_get_timer:
            try:
                grads, learner_stats = self.inqueue.get(timeout=30)
            except Empty:
                return

        # apply grad
        with self.grad_apply_timer:
            try:
                judge_is_nan(grads)
            except ValueError:
                grads = [tf.zeros_like(grad) for grad in grads]
                logger.info('Gradient contains NaN, zeroing it')

            self.local_worker.apply_gradients(self.iteration, grads)

        # log
        if self.iteration % self.args.log_interval == 0:
            logger.info('updating {} in total'.format(self.iteration))
            logger.info('sampling {} in total'.format(
                self.optimizer_stats['num_sampled_steps']))
            with self.writer.as_default():
                for key, val in learner_stats.items():
                    if not isinstance(val, list):
                        tf.summary.scalar(
                            'optimizer/learner_stats/scalar/{}'.format(key),
                            val,
                            step=self.iteration)
                    else:
                        assert isinstance(val, list)
                        for i, v in enumerate(val):
                            tf.summary.scalar(
                                'optimizer/learner_stats/list/{}/{}'.format(
                                    key, i),
                                v,
                                step=self.iteration)
                for key, val in self.optimizer_stats.items():
                    tf.summary.scalar('optimizer/{}'.format(key),
                                      val,
                                      step=self.iteration)
                self.writer.flush()

        # evaluate
        if self.iteration % self.args.eval_interval == 0:
            self.evaluator.set_weights.remote(self.local_worker.get_weights())
            if self.args.obs_preprocess_type == 'normalize' or self.args.reward_preprocess_type == 'normalize':
                self.evaluator.set_ppc_params.remote(
                    self.local_worker.get_ppc_params())
            self.evaluator.run_evaluation.remote(self.iteration)

        # save
        if self.iteration % self.args.save_interval == 0:
            self.local_worker.save_weights(self.model_dir, self.iteration)
            self.workers['remote_workers'][0].save_ppc_params.remote(
                self.model_dir)

        self.iteration += 1
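
The timer objects referenced above (update_timer, grad_queue_get_timer, grad_apply_timer) are not defined in this example. They are used like Ray RLlib's TimerStat: a context manager that records wall-clock durations and exposes running means. The class below is a minimal sketch under that assumption, not the original implementation.

import time

class TimerStat:
    """Context manager that tracks recent durations and optional throughput."""

    def __init__(self, window_size=100):
        self._window_size = window_size
        self._samples = []           # recent durations in seconds
        self._units_processed = []   # optional unit counts for mean_throughput
        self._start = None

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._samples.append(time.time() - self._start)
        if len(self._samples) > self._window_size:
            self._samples.pop(0)

    def push_units_processed(self, n):
        self._units_processed.append(n)
        if len(self._units_processed) > self._window_size:
            self._units_processed.pop(0)

    @property
    def mean(self):
        return sum(self._samples) / max(len(self._samples), 1)

    @property
    def mean_throughput(self):
        return sum(self._units_processed) / max(sum(self._samples), 1e-8)
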
    def step(self):
        # sampling
        sampling_interval = 10
        if self.iteration % sampling_interval == 0:
            with self.timers['sampling_timer']:
                sample_batch, count = self.worker.sample_with_count()
                self.num_sampled_steps += count
                self.replay_buffer.add_batch(sample_batch)

        # replay
        with self.timers["replay_timer"]:
            samples = self.replay_buffer.replay()

        # learning
        with self.timers['learning_timer']:
            self.learner.set_weights(self.worker.get_weights())
            if self.args.obs_preprocess_type == 'normalize' or \
                    self.args.reward_preprocess_type == 'normalize':
                self.learner.set_ppc_params(self.worker.get_ppc_params())
            grads = self.learner.compute_gradient(samples[:-1],
                                                  self.replay_buffer,
                                                  samples[-1], self.iteration)
            learner_stats = self.learner.get_stats()
            if self.args.buffer_type == 'priority':
                info_for_buffer = self.learner.get_info_for_buffer()
                info_for_buffer['rb'].update_priorities(
                    info_for_buffer['indexes'], info_for_buffer['td_error'])

        # apply grad
        with self.timers['grad_apply_timer']:
            try:
                judge_is_nan(grads)
            except ValueError:
                grads = [tf.zeros_like(grad) for grad in grads]
                logger.info('Gradient contains NaN, zeroing it')
            self.worker.apply_gradients(self.iteration, grads)

        # log
        if self.iteration % self.args.log_interval == 0:
            logger.info('updating {} in total'.format(self.iteration))
            logger.info('sampling {} in total'.format(
                self.stats['num_sampled_steps']))
            with self.writer.as_default():
                for key, val in learner_stats.items():
                    if not isinstance(val, list):
                        tf.summary.scalar(
                            'optimizer/learner_stats/scalar/{}'.format(key),
                            val,
                            step=self.iteration)
                    else:
                        assert isinstance(val, list)
                        for i, v in enumerate(val):
                            tf.summary.scalar(
                                'optimizer/learner_stats/list/{}/{}'.format(
                                    key, i),
                                v,
                                step=self.iteration)
                for key, val in self.stats.items():
                    tf.summary.scalar('optimizer/{}'.format(key),
                                      val,
                                      step=self.iteration)
                self.writer.flush()

        # evaluate
        if self.iteration % self.args.eval_interval == 0 and self.evaluator is not None:
            self.evaluator.set_weights(self.worker.get_weights())
            self.evaluator.set_ppc_params(self.worker.get_ppc_params())
            self.evaluator.run_evaluation(self.iteration)

        # save
        if self.iteration % self.args.save_interval == 0:
            self.worker.save_weights(self.model_dir, self.iteration)
            self.worker.save_ppc_params(self.model_dir)

        self.get_stats()
        self.iteration += 1
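
get_stats() is called at the end of this step() but is not shown. Based on how self.stats is read in the logging branch (for example the 'num_sampled_steps' key), a plausible sketch is given below; the exact keys and timer fields are assumptions rather than the original code.

    def get_stats(self):
        # Assemble the stats dictionary consumed by the logging branch above.
        self.stats.update(dict(
            iteration=self.iteration,
            num_sampled_steps=self.num_sampled_steps,
            sampling_time=self.timers['sampling_timer'].mean,
            replay_time=self.timers['replay_timer'].mean,
            learning_time=self.timers['learning_timer'].mean,
            grad_apply_time=self.timers['grad_apply_timer'].mean))
        return self.stats
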
    def step(self):
        assert self.update_thread.is_alive()
        assert len(self.workers['remote_workers']) > 0
        weights = None
        ppc_params = None

        # sampling
        with self.timers['sampling_timer']:
            for worker, objID in self.sample_tasks.completed():
                sample_batch, count = ray.get(objID)
                random.choice(
                    self.replay_buffers).add_batch.remote(sample_batch)
                self.num_sampled_steps += count
                self.steps_since_update[worker] += count
                ppc_params = worker.get_ppc_params.remote()
                if self.steps_since_update[
                        worker] >= self.max_weight_sync_delay:
                    judge_is_nan(self.local_worker.policy_with_value.policy.
                                 trainable_weights)
                    if weights is None:
                        weights = ray.put(self.local_worker.get_weights())
                    worker.set_weights.remote(weights)
                    self.steps_since_update[worker] = 0
                self.sample_tasks.add(worker,
                                      worker.sample_with_count.remote())

        # replay
        with self.timers["replay_timer"]:
            for rb, replay in self.replay_tasks.completed():
                self.replay_tasks.add(rb, rb.replay.remote())
                if self.learner_queue.full():
                    self.num_samples_dropped += 1
                else:
                    samples = ray.get(replay)
                    self.learner_queue.put((rb, samples))

        # learning
        with self.timers['learning_timer']:
            for learner, objID in self.learn_tasks.completed():
                grads = ray.get(objID)
                learner_stats = ray.get(learner.get_stats.remote())
                if self.args.buffer_type == 'priority':
                    info_for_buffer = ray.get(
                        learner.get_info_for_buffer.remote())
                    info_for_buffer['rb'].update_priorities.remote(
                        info_for_buffer['indexes'],
                        info_for_buffer['td_error'])
                rb, samples = self.learner_queue.get(block=False)
                if ppc_params and \
                        (self.args.obs_preprocess_type == 'normalize' or self.args.reward_preprocess_type == 'normalize'):
                    learner.set_ppc_params.remote(ppc_params)
                    self.local_worker.set_ppc_params(ppc_params)
                if weights is None:
                    weights = ray.put(self.local_worker.get_weights())
                learner.set_weights.remote(weights)
                self.learn_tasks.add(
                    learner,
                    learner.compute_gradient.remote(
                        samples[:-1], rb, samples[-1],
                        self.local_worker.iteration))
                if self.update_thread.inqueue.full():
                    self.num_grads_dropped += 1
                # if inqueue is a bounded queue.Queue, this put() blocks until
                # the update thread frees a slot
                self.update_thread.inqueue.put([grads, learner_stats])

        self.iteration = self.update_thread.iteration
        self.optimizer_steps += 1
        self.get_stats()
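
sample_tasks, replay_tasks and learn_tasks are used like Ray RLlib's TaskPool: each in-flight object ref is mapped to the actor that produced it, and completed() yields the finished (actor, ref) pairs. The class below is a minimal sketch under that assumption; the original helper is not shown and may differ.

import ray

class TaskPool:
    """Tracks in-flight remote calls and yields the ones that have finished."""

    def __init__(self):
        self._tasks = {}  # object ref -> actor handle that produced it

    def add(self, worker, obj_ref):
        self._tasks[obj_ref] = worker

    def completed(self):
        pending = list(self._tasks)
        if pending:
            ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0.01)
            for obj_ref in ready:
                yield self._tasks.pop(obj_ref), obj_ref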