    def train(self):
        self.start_worker()
        self.init_opt()
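        # Main loop: sample trajectories, process them into training batches,
        # and take one policy optimization step per iteration.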
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #%d | ' % itr):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log("saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params["algo"] = self
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("saved")
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.update_plot()
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                                  "continue...")

        self.shutdown_worker()
Example #2
    def train(self):
        # This seems like a rather sequential method
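        # The replay pool stores flattened (observation, action, scaled reward,
        # terminal) tuples for off-policy minibatch updates.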
        pool = SimpleReplayPool(
            max_pool_size=self.replay_pool_size,
            observation_dim=self.env.observation_space.flat_dim,
            action_dim=self.env.action_space.flat_dim,
        )
        self.start_worker()

        self.init_opt()
        itr = 0
        path_length = 0
        path_return = 0
        terminal = False
        observation = self.env.reset()

        # Sampling uses an independent copy of the policy, synced from self.policy below.
        sample_policy = pickle.loads(pickle.dumps(self.policy))

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
                # Execute policy
                if terminal:  # or path_length > self.max_path_length:
                    # Note that if the last time step ends an episode, the very
                    # last state and observation will be ignored and not added
                    # to the replay pool
                    observation = self.env.reset()
                    self.es.reset()
                    sample_policy.reset()
                    self.es_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0
                action = self.es.get_action(itr,
                                            observation,
                                            policy=sample_policy)  # qf=qf)

                next_observation, reward, terminal, _ = self.env.step(action)
                path_length += 1
                path_return += reward

                if not terminal and path_length >= self.max_path_length:
                    terminal = True
                    # only include the terminal transition in this case if the flag was set
                    if self.include_horizon_terminal_transitions:
                        pool.add_sample(observation, action,
                                        reward * self.scale_reward, terminal)
                else:
                    pool.add_sample(observation, action,
                                    reward * self.scale_reward, terminal)

                observation = next_observation

                if pool.size >= self.min_pool_size:
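                    # Once the pool has enough warm-up samples, take several
                    # training updates per environment step, then sync the
                    # sampling policy with the freshly trained parameters.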
                    for update_itr in range(self.n_updates_per_sample):
                        # Train policy
                        batch = pool.random_batch(self.batch_size)
                        self.do_training(itr, batch)
                    sample_policy.set_param_values(
                        self.policy.get_param_values())

                itr += 1

            logger.log("Training finished")
            if pool.size >= self.min_pool_size:
                self.evaluate(epoch, pool)
                params = self.get_epoch_snapshot(epoch)
                logger.save_itr_params(epoch, params)
            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
        self.env.terminate()
        self.policy.terminate()
Example #3
    def train(self):

        cur_std = self.sigma0
        cur_mean = self.policy.get_param_values()
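        # The CMA-ES search distribution starts at the current policy parameters
        # (mean) with isotropic standard deviation sigma0.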
        es = cma_es_lib.CMAEvolutionStrategy(cur_mean, cur_std)

        parallel_sampler.populate_task(self.env, self.policy)
        if self.plot:
            plotter.init_plot(self.env, self.policy)

        itr = 0
        while itr < self.n_itr and not es.stop():
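            # Each generation: ask CMA-ES for candidate parameter vectors,
            # evaluate each with rollouts, then tell the fitnesses back.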

            if self.batch_size is None:
                # Sample from multivariate normal distribution.
                xs = es.ask()
                xs = np.asarray(xs)
                # For each sample, do a rollout.
                infos = (stateful_pool.singleton_pool.run_map(
                    sample_return,
                    [(x, self.max_path_length, self.discount) for x in xs]))
            else:
                cum_len = 0
                infos = []
                xss = []
                done = False
                while not done:
                    sbs = stateful_pool.singleton_pool.n_parallel * 2
                    # Sample from multivariate normal distribution.
                    # Ask CMA-ES for sbs candidate solutions at a time.
                    xs = es.ask(sbs)
                    xs = np.asarray(xs)

                    xss.append(xs)
                    sinfos = stateful_pool.singleton_pool.run_map(
                        sample_return,
                        [(x, self.max_path_length, self.discount) for x in xs])
                    for info in sinfos:
                        infos.append(info)
                        cum_len += len(info['returns'])
                        if cum_len >= self.batch_size:
                            xs = np.concatenate(xss)
                            done = True
                            break

            # Evaluate fitness of samples (negative as it is minimization
            # problem).
            fs = -np.array([info['returns'][0] for info in infos])
            # When batching, more candidates may have been generated than were
            # actually evaluated, so trim xs to match the number of fitness values.
            xs = xs[:len(fs)]
            # Update CMA-ES params based on sample fitness.
            es.tell(xs, fs)

            logger.push_prefix('itr #%d | ' % itr)
            logger.record_tabular('Iteration', itr)
            logger.record_tabular('CurStdMean', np.mean(cur_std))
            undiscounted_returns = np.array(
                [info['undiscounted_return'] for info in infos])
            logger.record_tabular('AverageReturn',
                                  np.mean(undiscounted_returns))
            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
            logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
            logger.record_tabular('MinReturn', np.min(undiscounted_returns))
            logger.record_tabular('AverageDiscountedReturn',
                                  -np.mean(fs))  # negate back: fs holds negated returns
            logger.record_tabular(
                'AvgTrajLen',
                np.mean([len(info['returns']) for info in infos]))
            self.env.log_diagnostics(infos)
            self.policy.log_diagnostics(infos)

            logger.save_itr_params(
                itr, dict(
                    itr=itr,
                    policy=self.policy,
                    env=self.env,
                ))
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                plotter.update_plot(self.policy, self.max_path_length)
            logger.pop_prefix()
            # Update iteration.
            itr += 1

        # Set final params.
        self.policy.set_param_values(es.result()[0])
        parallel_sampler.terminate_task()
    def train(self):
        with tf.Session() as sess:
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars (I know, it's ugly)
            uninit_vars = []
            for var in tf.global_variables():
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))
            #sess.run(tf.initialize_all_variables())
            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):

                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    #new_param_values = self.policy.get_variable_values(self.policy.all_params)

                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr,
                                                   samples_data)  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime',
                                          time.time() - itr_start_time)

                    #import pickle
                    #with open('paths_itr'+str(itr)+'.pkl', 'wb') as f:
                    #    pickle.dump(paths, f)

                    # debugging
                    """
                    if itr % 1 == 0:
                        logger.log("Saving visualization of paths")
                        import matplotlib.pyplot as plt;
                        for ind in range(5):
                            plt.clf(); plt.hold(True)
                            points = paths[ind]['observations']
                            plt.plot(points[:,0], points[:,1], '-r', linewidth=2)
                            plt.xlim([-1.0, 1.0])
                            plt.ylim([-1.0, 1.0])
                            plt.legend(['path'])
                            plt.savefig('/home/cfinn/path'+str(ind)+'.png')
                    """
                    # end debugging

                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        self.shutdown_worker()
    def train(self):
        # TODO - make this a util
        flatten_list = lambda l: [item for sublist in l for item in sublist]

        with tf.Session() as sess:
            # Code for loading a previous policy. Somewhat hacky because needs to be in sess.
            if self.load_policy is not None:
                import joblib
                self.policy = joblib.load(self.load_policy)['policy']
            self.init_opt()
            # initialize uninitialized vars  (only initialize vars that were not loaded)
            uninit_vars = []
            for var in tf.global_variables():
                # Note: this is hacky; there may be a better way to do this in newer TF.
                try:
                    sess.run(var)
                except tf.errors.FailedPreconditionError:
                    uninit_vars.append(var)
            sess.run(tf.variables_initializer(uninit_vars))
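            # A less brittle alternative (assuming TF 1.x) is to ask the graph
            # which variables are uninitialized instead of probing with sess.run:
            #   uninit_names = set(sess.run(tf.report_uninitialized_variables()))
            #   uninit_vars = [v for v in tf.global_variables()
            #                  if v.name.split(':')[0].encode() in uninit_names]
            #   sess.run(tf.variables_initializer(uninit_vars))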

            self.start_worker()
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Sampling set of tasks/goals for this meta-batch...")

                    # sample environment configuration
                    env = self.env
                    while not ('sample_env_params' in dir(env) or 'sample_goals' in dir(env)):
                        env = env._wrapped_env
                    if 'sample_goals' in dir(env):
                        learner_env_params = env.sample_goals(self.meta_batch_size)
                    elif 'sample_env_params' in dir(env):
                        learner_env_params = env.sample_env_params(self.meta_batch_size)


                    self.policy.switch_to_init_dist()  # Switch to pre-update policy

                    all_samples_data, all_paths = [], []
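                    # Steps 0..num_grad_updates-1 sample with the current policy and
                    # compute the adapted (post-gradient) distributions; the final step
                    # only collects post-update samples for the meta-optimization below.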
                    for step in range(self.num_grad_updates+1):
                        #if step > 0:
                        #    import pdb; pdb.set_trace() # test param_vals functions.
                        logger.log('** Step ' + str(step) + ' **')
                        logger.log("Obtaining samples...")
                        paths = self.obtain_samples(itr, reset_args=learner_env_params, log_prefix=str(step))
                        all_paths.append(paths)
                        logger.log("Processing samples...")
                        samples_data = {}
                        for key in paths.keys():  # the keys are the tasks
                            # Don't log here; it would spam the console with output for every task.
                            samples_data[key] = self.process_samples(itr, paths[key], log=False)
                        all_samples_data.append(samples_data)
                        # for logging purposes
                        self.process_samples(itr, flatten_list(paths.values()), prefix=str(step), log=True)
                        logger.log("Logging diagnostics...")
                        self.log_diagnostics(flatten_list(paths.values()), prefix=str(step))
                        if step < self.num_grad_updates:
                            logger.log("Computing policy updates...")
                            self.policy.compute_updated_dists(samples_data)


                    logger.log("Optimizing policy...")
                    # This needs to take all samples_data so that it can construct graph for meta-optimization.
                    self.optimize_policy(itr, all_samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, all_samples_data[-1])  # , **kwargs)
                    if self.store_paths:
                        params["paths"] = all_samples_data[-1]["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)

                    logger.dump_tabular(with_prefix=False)

                    # The rest is example plotting code, useful for visualizing
                    # trajectories across a few different tasks (disabled by the
                    # `if False` guards below).
                    if False and itr % 2 == 0 and self.env.observation_space.shape[0] <= 4: # point-mass
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            plt.plot(learner_env_params[ind][0], learner_env_params[ind][1], 'k*', markersize=10)

                            preupdate_paths = all_paths[0]
                            postupdate_paths = all_paths[-1]

                            pre_points = preupdate_paths[ind][0]['observations']
                            post_points = postupdate_paths[ind][0]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-b', linewidth=1)

                            pre_points = preupdate_paths[ind][1]['observations']
                            post_points = postupdate_paths[ind][1]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '--r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '--b', linewidth=1)

                            pre_points = preupdate_paths[ind][2]['observations']
                            post_points = postupdate_paths[ind][2]['observations']
                            plt.plot(pre_points[:,0], pre_points[:,1], '-.r', linewidth=2)
                            plt.plot(post_points[:,0], post_points[:,1], '-.b', linewidth=1)

                            plt.plot(0,0, 'k.', markersize=5)
                            plt.xlim([-0.8, 0.8])
                            plt.ylim([-0.8, 0.8])
                            plt.legend(['goal', 'preupdate path', 'postupdate path'])
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'prepost_path'+str(ind)+'.png'))
                    elif False and itr % 2 == 0:  # swimmer or cheetah
                        logger.log("Saving visualization of paths")
                        for ind in range(min(5, self.meta_batch_size)):
                            plt.clf()
                            goal_vel = learner_env_params[ind]
                            plt.title('Swimmer paths, goal vel='+str(goal_vel))

                            prepathobs = all_paths[0][ind][0]['observations']
                            postpathobs = all_paths[-1][ind][0]['observations']
                            plt.plot(prepathobs[:,0], prepathobs[:,1], '-r', linewidth=2)
                            plt.plot(postpathobs[:,0], postpathobs[:,1], '--b', linewidth=1)
                            plt.plot(prepathobs[-1,0], prepathobs[-1,1], 'r*', markersize=10)
                            plt.plot(postpathobs[-1,0], postpathobs[-1,1], 'b*', markersize=10)
                            plt.xlim([-1.0, 5.0])
                            plt.ylim([-1.0, 1.0])

                            plt.legend(['preupdate path', 'postupdate path'], loc=2)
                            plt.savefig(osp.join(logger.get_snapshot_dir(), 'swim1d_prepost_itr'+str(itr)+'_id'+str(ind)+'.pdf'))
        self.shutdown_worker()