Ejemplo n.º 1
0
    def train(self):
        """Run training from ``self.current_itr`` up to ``self.n_itr``.

        Each iteration does the following:

        * obtains and processes samples through ``self.sampler``;
        * logs diagnostics and optimizes the policy;
        * saves a snapshot containing the algorithm itself under
          ``params['algo']``, plus the processed paths when
          ``self.store_paths`` is set.

        When ``self.plot`` is set, the policy is plotted after every
        iteration, optionally pausing for user input.

        The plotter is closed and ``self.shutdown_worker()`` is called in
        every case, including when an iteration raises.
        """
        import contextlib

        with contextlib.ExitStack() as cleanup:
            plotter = Plotter()
            # Callbacks run in reverse order of registration, so the plotter
            # is closed after the worker has been shut down.
            cleanup.callback(plotter.close)
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            self.start_worker()
            cleanup.callback(self.shutdown_worker)
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    paths = self.sampler.obtain_samples(itr)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    self.optimize_policy(itr, samples_data)
                    logger.log('Saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    # Advance the counter before ``self`` is stored in the
                    # snapshot, presumably so a run resumed from it starts at
                    # the next iteration.
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('saved')
                    logger.log(tabular)
                    if self.plot:
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')
Ejemplo n.º 2
0
    def train(self, sess=None):
        """Train the policy while reporting lifecycle events to a listener.

        ``ExpLifecycle`` events are sent over a connection to
        ``localhost:6000`` as training progresses through its phases.

        Args:
            sess: TensorFlow session to run in. If None, a new session is
                created and entered here, and closed when training ends.

        Returns:
            ``samples_data["average_return"]`` from the last completed
            iteration, or None if no iteration ran.

        The worker, any session created here, and the connection are
        released even if an iteration raises. SHUTDOWN is only sent after
        a successful run.
        """
        import contextlib

        last_average_return = None
        with contextlib.ExitStack() as cleanup:
            # Callbacks run in reverse order of registration. That gives
            # shutdown_worker, then sess.close, then conn.close.
            conn = Client(("localhost", 6000))
            cleanup.callback(conn.close)

            created_session = sess is None
            if created_session:
                sess = tf.Session()
                sess.__enter__()
                # Only close sessions this method created; a caller-supplied
                # session stays under the caller's control.
                cleanup.callback(sess.close)

            sess.run(tf.global_variables_initializer())
            conn.send(ExpLifecycle.START)
            self.start_worker(sess)
            cleanup.callback(self.shutdown_worker)
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #{} | '.format(itr)):
                    logger.log("Obtaining samples...")
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.process_samples(itr, paths)
                    last_average_return = samples_data["average_return"]
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    snapshotter.save_itr_params(itr, params)
                    logger.log("Saved")
                    tabular.record('Time', time.time() - start_time)
                    tabular.record('ItrTime', time.time() - itr_start_time)
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")

            conn.send(ExpLifecycle.SHUTDOWN)
        return last_average_return
Ejemplo n.º 3
0
    def save_snapshot(self, itr, paths=None):
        """Persist the algorithm state for one iteration.

        The stored snapshot is the mapping returned by
        ``self.algo.get_itr_snapshot(itr)``. It is extended with the
        environment and, when provided, the sample paths.

        Args:
            itr: Iteration (epoch) index the snapshot is saved under.
            paths: Preprocessed batch of samples. Left out of the snapshot
                when None or empty.

        """
        assert self.has_setup

        logger.log("Saving snapshot...")
        snapshot = self.algo.get_itr_snapshot(itr)
        snapshot['env'] = self.env
        if paths:
            # Only attach samples when there is something to store.
            snapshot['paths'] = paths
        snapshotter.save_itr_params(itr, snapshot)
        logger.log('Saved')
Ejemplo n.º 4
0
    def train(self):
        """Run training while reporting lifecycle events to a listener.

        ``ExpLifecycle`` events are sent over a connection to
        ``localhost:6000``. Each iteration does the following:

        * obtains and processes samples through ``self.sampler``;
        * logs diagnostics and optimizes the policy;
        * saves a snapshot containing the algorithm itself under
          ``params['algo']``, plus the processed paths when
          ``self.store_paths`` is set.

        When ``self.plot`` is set, the policy is plotted after every
        iteration.

        The worker, the plotter and the connection are released even if an
        iteration raises. SHUTDOWN is only sent after a successful run.
        """
        import contextlib

        with contextlib.ExitStack() as cleanup:
            # Callbacks run in reverse order of registration. That gives
            # shutdown_worker, then plotter.close, then conn.close.
            conn = Client(('localhost', 6000))
            cleanup.callback(conn.close)
            plotter = Plotter()
            cleanup.callback(plotter.close)
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            cleanup.callback(self.shutdown_worker)
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    # Advance the counter before ``self`` is stored in the
                    # snapshot, presumably so a run resumed from it starts at
                    # the next iteration.
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('saved')
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')

            conn.send(ExpLifecycle.SHUTDOWN)
Ejemplo n.º 5
0
    def train(self):
        """Optimize the policy parameters with CMA-ES.

        The search starts from the policy's current parameters with initial
        step size ``self.sigma0``. It runs until ``self.n_itr`` iterations
        have completed or ``es.stop()`` returns a truthy value.

        Candidates are evaluated with ``sample_return``. The fitness passed
        to ``es.tell`` is the negated ``info['returns'][0]``, because
        CMA-ES minimizes. After the loop the policy is set to the first
        entry of the strategy's result (the best solution found, per
        pycma).

        The parallel sampler task and the plotter are released even if an
        iteration raises.
        """
        import contextlib

        # NOTE(review): cur_std is never updated, so 'CurStdMean' always logs
        # sigma0; es.sigma may be the intended quantity -- confirm.
        cur_std = self.sigma0
        es = cma.CMAEvolutionStrategy(self.policy.get_param_values(), cur_std)

        with contextlib.ExitStack() as cleanup:
            # Callbacks run in reverse order of registration. That gives
            # terminate_task, then plotter.close.
            cleanup.callback(self.plotter.close)
            parallel_sampler.populate_task(self.env, self.policy)
            cleanup.callback(parallel_sampler.terminate_task)
            if self.plot:
                self.plotter.init_plot(self.env, self.policy)

            itr = 0
            while itr < self.n_itr and not es.stop():
                xs, infos = self._sample_candidates(es)

                # Fitness is negated because CMA-ES solves a minimization
                # problem.
                fs = -np.array([info['returns'][0] for info in infos])
                # In batch mode more candidates may have been asked for than
                # were evaluated, so the unevaluated tail is dropped.
                xs = xs[:len(fs)]
                es.tell(xs, fs)

                with logger.prefix('itr #{} | '.format(itr)):
                    undiscounted_returns = np.array(
                        [info['undiscounted_return'] for info in infos])
                    tabular.record('Iteration', itr)
                    tabular.record('CurStdMean', np.mean(cur_std))
                    tabular.record('AverageReturn',
                                   np.mean(undiscounted_returns))
                    tabular.record('StdReturn', np.std(undiscounted_returns))
                    tabular.record('MaxReturn', np.max(undiscounted_returns))
                    tabular.record('MinReturn', np.min(undiscounted_returns))
                    tabular.record('AverageDiscountedReturn', np.mean(fs))
                    tabular.record(
                        'AvgTrajLen',
                        np.mean([len(info['returns']) for info in infos]))
                    self.policy.log_diagnostics(infos)
                    snapshotter.save_itr_params(
                        itr, dict(
                            itr=itr,
                            policy=self.policy,
                            env=self.env,
                        ))
                    logger.log(tabular)
                    if self.plot:
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                itr += 1

            # pycma < 2 exposes ``result`` as a method. Newer versions expose
            # it as a property returning a namedtuple. Support both.
            result = es.result
            if callable(result):
                result = result()
            self.policy.set_param_values(result[0])

    def _sample_candidates(self, es):
        """Ask ``es`` for candidate parameters and evaluate them.

        Args:
            es: The CMA-ES strategy to ask for candidates.

        Returns:
            Tuple ``(xs, infos)``. ``xs`` is an array of candidates stacked
            along axis 0, and ``infos`` holds the ``sample_return`` results.
            In batch mode ``xs`` may have more rows than ``infos`` has
            entries.
        """
        pool = stateful_pool.singleton_pool
        if self.batch_size is None:
            # Let the strategy choose how many candidates to propose.
            xs = np.asarray(es.ask())
            infos = pool.run_map(
                sample_return,
                [(x, self.max_path_length, self.discount) for x in xs])
            return xs, infos

        # Batch mode: keep asking for small groups of candidates until the
        # summed length of the returned ``returns`` reaches batch_size.
        cum_len = 0
        infos = []
        xss = []
        while True:
            sbs = pool.n_parallel * 2
            xs = np.asarray(es.ask(sbs))
            xss.append(xs)
            sinfos = pool.run_map(
                sample_return,
                [(x, self.max_path_length, self.discount) for x in xs])
            for info in sinfos:
                infos.append(info)
                cum_len += len(info['returns'])
                if cum_len >= self.batch_size:
                    return np.concatenate(xss), infos
Ejemplo n.º 6
0
Archivo: cem.py Proyecto: psxz/garage
    def train(self):
        """Optimize the policy parameters with the cross-entropy method.

        Each iteration does the following:

        * samples parameter vectors around ``cur_mean`` with a per-parameter
          std that includes an extra variance term decaying with ``itr``;
        * collects rollouts in parallel;
        * refits the mean and std to the ``n_best`` highest-scoring
          samples.

        The policy is set to the best sample of each iteration before it is
        logged and snapshotted.

        The parallel sampler task and the plotter are released even if an
        iteration raises.
        """
        import contextlib

        # These settings do not change between iterations.
        if self.batch_size is None:
            criterion = 'paths'
            threshold = self.n_samples
        else:
            criterion = 'samples'
            threshold = self.batch_size

        with contextlib.ExitStack() as cleanup:
            # Callbacks run in reverse order of registration. That gives
            # terminate_task, then plotter.close.
            cleanup.callback(self.plotter.close)
            parallel_sampler.populate_task(self.env, self.policy)
            cleanup.callback(parallel_sampler.terminate_task)
            if self.plot:
                self.plotter.init_plot(self.env, self.policy)

            cur_std = self.init_std
            cur_mean = self.policy.get_param_values()
            n_best = max(1, int(self.n_samples * self.best_frac))

            for itr in range(self.n_itr):
                # The extra variance shrinks linearly and reaches zero at
                # itr == extra_decay_time.
                extra_var_mult = max(1.0 - itr / self.extra_decay_time, 0)
                sample_std = np.sqrt(
                    np.square(cur_std) +
                    np.square(self.extra_std) * extra_var_mult)
                infos = stateful_pool.singleton_pool.run_collect(
                    _worker_rollout_policy,
                    threshold=threshold,
                    args=(dict(
                        cur_mean=cur_mean,
                        sample_std=sample_std,
                        max_path_length=self.max_path_length,
                        discount=self.discount,
                        criterion=criterion,
                        n_evals=self.n_evals), ))
                xs = np.asarray([info[0] for info in infos])
                paths = [info[1] for info in infos]

                fs = np.array([path['returns'][0] for path in paths])
                # Indices of the top n_best samples, best first.
                best_inds = (-fs).argsort()[:n_best]
                best_xs = xs[best_inds]
                cur_mean = best_xs.mean(axis=0)
                cur_std = best_xs.std(axis=0)
                best_x = best_xs[0]

                with logger.prefix('itr #{} | '.format(itr)):
                    undiscounted_returns = np.array(
                        [path['undiscounted_return'] for path in paths])
                    tabular.record('Iteration', itr)
                    tabular.record('CurStdMean', np.mean(cur_std))
                    tabular.record('AverageReturn',
                                   np.mean(undiscounted_returns))
                    tabular.record('StdReturn', np.std(undiscounted_returns))
                    tabular.record('MaxReturn', np.max(undiscounted_returns))
                    tabular.record('MinReturn', np.min(undiscounted_returns))
                    tabular.record('AverageDiscountedReturn', np.mean(fs))
                    tabular.record('NumTrajs', len(paths))
                    # Flatten the per-sample path lists (relevant when
                    # n_evals > 1).
                    full_paths = list(chain.from_iterable(
                        path['full_paths'] for path in paths))
                    tabular.record(
                        'AvgTrajLen',
                        np.mean([len(path['returns']) for path in full_paths]))

                    self.policy.set_param_values(best_x)
                    self.policy.log_diagnostics(full_paths)
                    snapshotter.save_itr_params(
                        itr,
                        dict(
                            itr=itr,
                            policy=self.policy,
                            env=self.env,
                            cur_mean=cur_mean,
                            cur_std=cur_std,
                        ))
                    logger.log(tabular)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)