Example #1
    def run(self):
        """Run the synchronous DQN training loop.

        Builds the session and models, then alternates between playing
        steps (`self._play`) and training (`self._train`) until
        `self.args.train_dqn_steps` timesteps have elapsed, at which point
        the model is checkpointed and the final weights are returned.

        Returns:
            tuple: (final network weights, `self._environments_cache`).
        """
        self._set_io_size()
        self._build_session()

        # Wall-clock timers passed to log_results for progress reporting.
        timer_dict = OrderedDict()
        timer_dict['Program Start'] = time.time()
        # Per-key history of training stats; averaged over the most recent
        # `print_frequency` entries at log time.
        rolling_stats = defaultdict(list)
        training_return = {}

        with self._session as sess:
            self._build_models()
            self._init_whitening_stats()

            # Accumulates per-step rollout data between training updates.
            data_dict = defaultdict(list)
            while True:
                # Periodic logging: every `print_frequency` timesteps,
                # summarize the rolling stats and print them.
                if (self.timesteps_so_far % self.args.print_frequency) == 0:
                    timer_dict['** Program Total Time **'] = time.time()
                    training_return['stats'] = {}
                    for key in rolling_stats:
                        training_return['stats'][key] = np.mean(
                            rolling_stats[key][-self.args.print_frequency:])

                    # Rescale per-step mean reward to a per-episode figure.
                    if 'mean_rewards' in training_return['stats']:
                        training_return['stats']['mean_rewards'] *= \
                           self.args.episode_length

                    training_return['iteration'] = \
                        self.timesteps_so_far//self.args.print_frequency
                    training_return['totalsteps'] = self.timesteps_so_far

                    log_results(training_return, timer_dict)

                # Collect one step of interaction data.
                temp_dict = self._play()
                for key in temp_dict:
                    data_dict[key].append(temp_dict[key])

                # Train only after the warm-up period, at the configured
                # frequency; the collected data is consumed and cleared.
                if self.timesteps_so_far >= self.args.dqn_training_start and \
                    (self.timesteps_so_far % self.args.dqn_train_freq == 0):
                    stats, _ = self._train(data_dict)

                    for key in stats:
                        rolling_stats[key].append(stats[key])
                    # log and print the results

                    # clear data_dict
                    data_dict = defaultdict(list)

                # Stop condition: checkpoint the model and exit the loop.
                if self.timesteps_so_far >= self.args.train_dqn_steps:
                    self._saver.save(
                        self._session,
                        osp.join(
                            logger._get_path(),
                            "pretrained_model_{}".format(self._name_scope)))
                    break
                else:
                    self.timesteps_so_far += 1

            # Snapshot weights while the session is still open.
            final_weights = self._get_weights()

        return final_weights, self._environments_cache
Example #2
    def run(self):
        """Service loop for the learner worker.

        Blocks on `self._task_queue` and dispatches on the signal in
        `next_task[0]`:
          - END_SIGNAL / None: acknowledge and exit the loop.
          - START_SIGNAL: reply with current network weights.
          - RESET_SIGNAL: reset whitening stats and step/iteration counters.
          - SAVE_SIGNAL: checkpoint the model under the logger path.
          - TRAIN_SIGNAL: update parameters and reply with weights + stats.

        Side effects: puts results on `self._result_queue`; every consumed
        task is acknowledged with `task_done()` so queue `join()` works.
        """
        self._set_io_size()
        self._build_models()
        self._init_replay_buffer()
        self._init_whitening_stats()

        # load the model if needed
        if self.args.ckpt_name is not None:
            self._restore_all()

        # the main training process
        while True:
            next_task = self._task_queue.get()

            print(next_task)

            if next_task[0] is None or next_task[0] == parallel_util.END_SIGNAL:
                # kill the learner
                self._task_queue.task_done()
                break

            elif next_task[0] == parallel_util.START_SIGNAL:
                # get network weights
                self._task_queue.task_done()
                self._result_queue.put(self._get_weights())

            elif next_task[0] == parallel_util.RESET_SIGNAL:
                self._task_queue.task_done()
                self._init_whitening_stats()
                self._timesteps_so_far = 0
                self._iteration = 0

            elif next_task[0] == parallel_util.SAVE_SIGNAL:
                _save_root = next_task[1]['net']
                _log_path = logger._get_path()

                # Checkpoint name encodes the network root, scope and step
                # count, e.g. "net_scope_1000.ckpt".
                _save_extension = _save_root + \
                    "_{}_{}.ckpt".format(
                        self._name_scope, self._timesteps_so_far
                    )

                _save_dir = osp.join(_log_path, _save_extension)
                self._saver.save(self._session, _save_dir)
                # Fix: this branch previously never acknowledged the task,
                # which would hang a task_queue.join() after a save request.
                self._task_queue.task_done()

            else:
                # training
                assert next_task[0] == parallel_util.TRAIN_SIGNAL
                stats = self._update_parameters(next_task[1]['data'],
                                                next_task[1]['training_info'])
                self._task_queue.task_done()

                self._iteration += 1
                return_data = {
                    'network_weights': self._get_weights(),
                    'stats': stats,
                    'totalsteps': self._timesteps_so_far,
                    'iteration': self._iteration
                }
                self._result_queue.put(return_data)
Example #3
    def __init__(self, env_name, *args, **kwargs):
        """Create the wrapped environment, stripping a trailing '__render'
        tag from *env_name* before registration lookup.

        Extra positional/keyword args are forwarded to the env factory.
        """
        render_suffix = '__render'
        clean_name = env_name
        if clean_name.endswith(render_suffix):
            clean_name = clean_name[:-len(render_suffix)]

        self.env_name = clean_name
        self.env, _ = env.env_register.make_env(self.env_name, *args, **kwargs)
        self.episode_number = 0

        # Output path is taken from the shared logger.
        self.path = logger._get_path()

        # Rendering state: frame buffer and render toggles.
        self.obs_buffer = []
        self.always_render = False
        self.render_name = ''
Example #4
    def run(self):
        """Run the transfer-learning training loop.

        Builds environments and models, optionally seeds weights from
        `self.weights`, generates trajectories, then repeatedly trains the
        'transfer' model until `self.args.transfer_iterations` iterations,
        saving a checkpoint before exiting.

        Returns:
            The network weights via `self._get_weights()`.
        """
        self._set_io_size()
        self._build_session()

        timer_dict = OrderedDict()
        timer_dict['Program Start'] = time.time()

        self._build_environments()

        with self._session as sess:
            self._build_models()
            # Warm-start from externally supplied weights when available.
            if self.weights is not None:
                self._set_weights(self.weights)
            self._network['explore'].set_environments(self._environments_cache)
            self._network['explore'].action_size = self._action_size
            buffer = self._generate_trajectories()
            self._init_whitening_stats()

            while True:
                # NOTE(review): `replay_buffer` is not defined in this method;
                # the trajectories above are bound to `buffer`. Either
                # `replay_buffer` is a module-level object not visible here,
                # or this should read `buffer.sample(...)` — confirm.
                data_dict = replay_buffer.sample(
                    self.args.transfer_sample_traj)
                episode_length = self.args.lookahead_increment

                training_info = {'episode_length': episode_length}

                # NOTE(review): other methods in this file access networks via
                # `self._network[...]`; `self.models['transfer']` may be an
                # inconsistency — verify against the class definition.
                stats, _ = self.models['transfer'].train(
                    data_dict, buffer, training_info)

                timer_dict['** Program Total Time **'] = time.time()
                log_results(stats, timer_dict)

                # Stop condition: checkpoint the transfer model and exit.
                if self.current_iteration > self.args.transfer_iterations:
                    _log_path = logger._get_path()
                    _save_root = 'transfer'
                    _save_extension = _save_root + \
                      "_{}_{}.ckpt".format(
                          self._name_scope, self._timesteps_so_far
                      )

                    _save_dir = osp.join(_log_path, _save_extension)
                    self._saver.save(self._session, _save_dir)
                    break

                else:
                    self.current_iteration += 1
        return self._get_weights()