    def test_load(self, load_mode, last_epoch):
        snapshotter = Snapshotter()
        saved = snapshotter.load(self.temp_dir.name, load_mode)

        assert isinstance(saved['algo'], VPG)
        assert isinstance(saved['env'], GymEnv)
        assert isinstance(saved['algo'].policy, CategoricalMLPPolicy)
        assert saved['stats'].total_epoch == last_epoch
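For reference, a minimal sketch of how the loading API exercised above is typically used outside of a test; the experiment path is a placeholder, and the second argument may be 'last', 'first', or an integer epoch (the load_mode values this test is parametrized over):

from garage.experiment import Snapshotter

snapshotter = Snapshotter()
saved = snapshotter.load('data/local/experiment', 'last')
policy = saved['algo'].policy
env = saved['env']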
Example #2
    def test_conflicting_params(self):
        with pytest.raises(ValueError):
            Snapshotter(snapshot_dir=self.temp_dir.name,
                        snapshot_mode='last',
                        snapshot_gap=2)

        with pytest.raises(ValueError):
            Snapshotter(snapshot_dir=self.temp_dir.name,
                        snapshot_mode='gap_overwrite',
                        snapshot_gap=1)
Example #3
    def test_gap_overwrite(self):
        snapshotter = Snapshotter(self.temp_dir.name, 'gap_overwrite', 2)
        assert snapshotter.snapshot_dir == self.temp_dir.name
        assert snapshotter.snapshot_mode == 'gap_overwrite'
        assert snapshotter.snapshot_gap == 2

        snapshot_data = [{'testparam': 1}, {'testparam': 4}]
        snapshotter.save_itr_params(1, snapshot_data[0])
        snapshotter.save_itr_params(2, snapshot_data[1])

        filename = osp.join(self.temp_dir.name, 'params.pkl')
        assert osp.exists(filename)
        with open(filename, 'rb') as pkl_file:
            data = pickle.load(pkl_file)
            assert data == snapshot_data[1]
Example #4
    def __init__(self, snapshot_config=None, max_cpus=1):
        if snapshot_config:
            self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                            snapshot_config.snapshot_mode,
                                            snapshot_config.snapshot_gap)
        else:
            self._snapshotter = Snapshotter()

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None
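A hedged sketch of calling this constructor with an explicit snapshot configuration; it only assumes an object exposing the three fields the constructor reads (snapshot_dir, snapshot_mode, snapshot_gap), modeled here as a plain namedtuple rather than any particular garage type. LocalRunner is the class this __init__ belongs to (shown in full in a later example):

import collections

SnapshotConfig = collections.namedtuple(
    'SnapshotConfig', ['snapshot_dir', 'snapshot_mode', 'snapshot_gap'])

config = SnapshotConfig(snapshot_dir='data/local/experiment',
                        snapshot_mode='last',
                        snapshot_gap=1)
runner = LocalRunner(snapshot_config=config, max_cpus=1)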
Example #5
    def test_snapshotter(self, mode, files):
        snapshotter = Snapshotter(self.temp_dir.name, mode, 2)

        assert snapshotter.snapshot_dir == self.temp_dir.name
        assert snapshotter.snapshot_mode == mode
        assert snapshotter.snapshot_gap == 2

        snapshot_data = [{'testparam': 1}, {'testparam': 4}]
        snapshotter.save_itr_params(1, snapshot_data[0])
        snapshotter.save_itr_params(2, snapshot_data[1])

        for f, num in files.items():
            filename = osp.join(self.temp_dir.name, f)
            assert osp.exists(filename)
            with open(filename, 'rb') as pkl_file:
                data = pickle.load(pkl_file)
                assert data == snapshot_data[num]
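The configurations parametrization is not shown here; a plausible reconstruction (an assumption, not the original fixture) pairs each snapshot mode with the files it should produce and the index into snapshot_data each file should hold, given snapshot_gap=2 and iterations 1 and 2:

configurations = [
    ('all', {'itr_1.pkl': 0, 'itr_2.pkl': 1}),
    ('last', {'params.pkl': 1}),
    ('gap', {'itr_2.pkl': 1}),
    ('gap_and_last', {'itr_2.pkl': 1, 'params.pkl': 1}),
]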
Example #6
def watch_atari(saved_dir, env=None, num_episodes=10):
    """Watch a trained agent play an atari game.

    Args:
        saved_dir (str): Directory containing the pickle file.
        env (str): Environment to run episodes on. If None, the pickled
            environment is used.
        num_episodes (int): Number of episodes to play. Note that when using
            the EpisodicLife wrapper, an episode is considered done when a
            life is lost. Defaults to 10.
    """
    snapshotter = Snapshotter()
    data = snapshotter.load(saved_dir)
    if env is not None:
        env = gym.make(env)
        env = Noop(env, noop_max=30)
        env = MaxAndSkip(env, skip=4)
        env = EpisodicLife(env)
        if 'FIRE' in env.unwrapped.get_action_meanings():
            env = FireReset(env)
        env = Grayscale(env)
        env = Resize(env, 84, 84)
        env = ClipReward(env)
        env = StackFrames(env, 4, axis=0)
        env = GymEnv(env)
    else:
        env = data['env']

    exploration_policy = data['algo'].exploration_policy
    exploration_policy.policy._qf.to('cpu')
    ep_rewards = np.asarray([])
    for _ in range(num_episodes):
        episode_data = rollout(env,
                               exploration_policy.policy,
                               animated=True,
                               pause_per_frame=0.02)
        ep_rewards = np.append(ep_rewards, np.sum(episode_data['rewards']))

    print('Average Reward {}'.format(np.mean(ep_rewards)))
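A possible invocation of the helper above (the experiment directory and environment id are placeholders):

watch_atari('data/local/experiment/dqn_pong', env='PongNoFrameskip-v4',
            num_episodes=5)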
Example #7
def main():
    snapshotter = Snapshotter()
    with tf.compat.v1.Session():
        print("loading model...")
        data = snapshotter.load(MODEL_PATH)
        print("model", data)
        policy = data['algo'].policy
        env = data['env']

        steps, max_steps = 0, 1000
        done = False
        env.render()
        obs = env.reset()  # The initial observation
        policy.reset()

        while steps < max_steps and not done:
            action = policy.get_action(obs)[0]
            obs, rew, done, _ = env.step(action)
            env.render()  # Render the environment to see what's going on (optional)
            steps += 1

        env.close()
Example #8
class TestSnapshotter(unittest.TestCase):
    def setUp(self):
        self.snapshot_dir = tempfile.TemporaryDirectory()
        self.snapshotter = Snapshotter()

    def tearDown(self):
        self.snapshotter.reset()
        self.snapshot_dir.cleanup()

    def test_set_snapshot_dir(self):
        self.snapshotter.snapshot_dir = self.snapshot_dir.name
        assert self.snapshotter.snapshot_dir == self.snapshot_dir.name

    @tools.params(*configurations)
    def test_snapshotter(self, mode, files):
        self.snapshotter.snapshot_dir = self.snapshot_dir.name

        self.snapshotter.snapshot_mode = mode
        assert self.snapshotter.snapshot_mode == mode
        self.snapshotter.snapshot_gap = 2
        assert self.snapshotter.snapshot_gap == 2

        snapshot_data = [{'testparam': 1}, {'testparam': 4}]
        self.snapshotter.save_itr_params(1, snapshot_data[0])
        self.snapshotter.save_itr_params(2, snapshot_data[1])

        for f, num in files.items():
            filename = osp.join(self.snapshot_dir.name, f)
            assert osp.exists(filename)
            with open(filename, 'rb') as pkl_file:
                data = pickle.load(pkl_file)
                assert data == snapshot_data[num]

    def test_invalid_snapshot_mode(self):
        with self.assertRaises(ValueError):
            self.snapshotter.snapshot_dir = self.snapshot_dir.name
            self.snapshotter.snapshot_mode = 'invalid'
            self.snapshotter.save_itr_params(2, {'testparam': 'invalid'})
Example #9
class TestSnapshotter:
    def setup_method(self):
        self.snapshot_dir = tempfile.TemporaryDirectory()
        self.snapshotter = Snapshotter()

    def teardown_method(self):
        self.snapshotter.reset()
        self.snapshot_dir.cleanup()

    def test_set_snapshot_dir(self):
        self.snapshotter.snapshot_dir = self.snapshot_dir.name
        assert self.snapshotter.snapshot_dir == self.snapshot_dir.name

    @pytest.mark.parametrize('mode, files', [*configurations])
    def test_snapshotter(self, mode, files):
        self.snapshotter.snapshot_dir = self.snapshot_dir.name

        self.snapshotter.snapshot_mode = mode
        assert self.snapshotter.snapshot_mode == mode
        self.snapshotter.snapshot_gap = 2
        assert self.snapshotter.snapshot_gap == 2

        snapshot_data = [{'testparam': 1}, {'testparam': 4}]
        self.snapshotter.save_itr_params(1, snapshot_data[0])
        self.snapshotter.save_itr_params(2, snapshot_data[1])

        for f, num in files.items():
            filename = osp.join(self.snapshot_dir.name, f)
            assert osp.exists(filename)
            with open(filename, 'rb') as pkl_file:
                data = pickle.load(pkl_file)
                assert data == snapshot_data[num]

    def test_invalid_snapshot_mode(self):
        with pytest.raises(ValueError):
            self.snapshotter.snapshot_dir = self.snapshot_dir.name
            self.snapshotter.snapshot_mode = 'invalid'
            self.snapshotter.save_itr_params(2, {'testparam': 'invalid'})
Example #10
    def test_load_with_invalid_load_mode(self):
        snapshotter = Snapshotter()
        with pytest.raises(ValueError):
            snapshotter.load(self.temp_dir.name, 'foo')
Example #11
    def test_invalid_snapshot_mode(self):
        with pytest.raises(ValueError):
            snapshotter = Snapshotter(snapshot_dir=self.temp_dir.name,
                                      snapshot_mode='invalid')
            snapshotter.save_itr_params(2, {'testparam': 'invalid'})
Example #12
    def setUp(self):
        self.snapshot_dir = tempfile.TemporaryDirectory()
        self.snapshotter = Snapshotter()
Example #13
class LocalRunner:
    """Base class of local runner.

    Use Runner.setup(algo, env) to set up the algorithm and environment for
    the runner, and Runner.train() to start training.

    Args:
        snapshot_config (garage.experiment.SnapshotConfig): The snapshot
            configuration used by LocalRunner to create the snapshotter.
            If None, it will create one with default settings.
        max_cpus (int): The maximum number of parallel sampler workers.

    Examples:
        # to train
        runner = LocalRunner()
        env = Env(...)
        policy = Policy(...)
        algo = Algo(
                env=env,
                policy=policy,
                ...)
        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)

        # to resume immediately.
        runner = LocalRunner()
        runner.restore(resume_from_dir)
        runner.resume()

        # to resume with modified training arguments.
        runner = LocalRunner()
        runner.restore(resume_from_dir)
        runner.resume(n_epochs=20)

    """
    def __init__(self, snapshot_config=None, max_cpus=1):
        if snapshot_config:
            self._snapshotter = Snapshotter(snapshot_config.snapshot_dir,
                                            snapshot_config.snapshot_mode,
                                            snapshot_config.snapshot_gap)
        else:
            self._snapshotter = Snapshotter()

        if max_cpus > 1:
            from garage.sampler import singleton_pool
            singleton_pool.initialize(max_cpus)
        self.has_setup = False
        self.plot = False

        self._setup_args = None
        self.train_args = None

    def setup(self, algo, env, sampler_cls=None, sampler_args=None):
        """Set up runner for algorithm and environment.

        This method saves algo and env within runner and creates a sampler.

        Note:
            After setup() is called, all variables in the session should have
            been initialized. setup() respects existing values in the session,
            so policy weights can be loaded before setup().

        Args:
            algo (garage.np.algos.RLAlgorithm): An algorithm instance.
            env (garage.envs.GarageEnv): An environment instance.
            sampler_cls (garage.sampler.Sampler): A sampler class.
            sampler_args (dict): Arguments to be passed to sampler constructor.

        """
        self.algo = algo
        self.env = env
        self.policy = self.algo.policy

        if sampler_args is None:
            sampler_args = {}
        if sampler_cls is None:
            sampler_cls = algo.sampler_cls
        self.sampler = sampler_cls(env, self.policy, **sampler_args)

        self.has_setup = True

        self._setup_args = types.SimpleNamespace(sampler_cls=sampler_cls,
                                                 sampler_args=sampler_args)

    def _start_worker(self):
        """Start Plotter and Sampler workers."""
        self.sampler.start_worker()
        if self.plot:
            from garage.tf.plotter import Plotter
            self.plotter = Plotter(self.env, self.policy)
            self.plotter.start()

    def _shutdown_worker(self):
        """Shutdown Plotter and Sampler workers."""
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.close()

    def obtain_samples(self, itr, batch_size):
        """Obtain one batch of samples.

        Args:
            itr(int): Index of iteration (epoch).
            batch_size(int): Number of steps in batch.
                This is a hint that the sampler may or may not respect.

        Returns:
            One batch of samples.

        """
        show_pbar = False
        if self.train_args.n_epoch_cycles == 1:
            logger.log('Obtaining samples...')
            show_pbar = True
        # TODO refactor logging to clean up args of obtain_samples
        return self.sampler.obtain_samples(itr, log=True, show_pbar=show_pbar)

    def save(self, epoch, paths=None):
        """Save snapshot of current batch.

        Args:
            epoch (int): Index of the epoch to save.
            paths (dict): Batch of samples after preprocessing. If None,
                no paths will be logged to the snapshot.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before saving.')

        logger.log('Saving snapshot...')

        params = dict()
        # Save arguments
        params['setup_args'] = self._setup_args
        params['train_args'] = self.train_args

        # Save states
        params['env'] = self.env
        # TODO: add this back
        # params['algo'] = self.algo
        if paths:
            params['paths'] = paths
        params['last_epoch'] = epoch
        self._snapshotter.save_itr_params(epoch, params)

        logger.log('Saved')

    def restore(self, from_dir, from_epoch='last'):
        """Restore experiment from snapshot.

        Args:
            from_dir (str): Directory of the pickle file
                to resume experiment from.
            from_epoch (str or int): The epoch to restore from.
                Can be 'first', 'last' or a number.
                Not applicable when snapshot_mode='last'.

        Returns:
            A SimpleNamespace for train()'s arguments.

        """
        saved = self._snapshotter.load(from_dir, from_epoch)

        self._setup_args = saved['setup_args']
        self.train_args = saved['train_args']

        self.setup(env=saved['env'],
                   algo=saved['algo'],
                   sampler_cls=self._setup_args.sampler_cls,
                   sampler_args=self._setup_args.sampler_args)

        n_epochs = self.train_args.n_epochs
        last_epoch = saved['last_epoch']
        n_epoch_cycles = self.train_args.n_epoch_cycles
        batch_size = self.train_args.batch_size
        store_paths = self.train_args.store_paths
        pause_for_plot = self.train_args.pause_for_plot

        fmt = '{:<20} {:<15}'
        logger.log('Restore from snapshot saved in %s' %
                   self._snapshotter.snapshot_dir)
        logger.log(fmt.format('Train Args', 'Value'))
        logger.log(fmt.format('n_epochs', n_epochs))
        logger.log(fmt.format('last_epoch', last_epoch))
        logger.log(fmt.format('n_epoch_cycles', n_epoch_cycles))
        logger.log(fmt.format('batch_size', batch_size))
        logger.log(fmt.format('store_paths', store_paths))
        logger.log(fmt.format('pause_for_plot', pause_for_plot))

        self.train_args.start_epoch = last_epoch + 1
        return copy.copy(self.train_args)

    def log_diagnostics(self, pause_for_plot=False):
        """Log diagnostics.

        Args:
            pause_for_plot(bool): Pause for plot.

        """
        logger.log('Time %.2f s' % (time.time() - self._start_time))
        logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time))
        logger.log(tabular)
        if self.plot:
            self.plotter.update_plot(self.policy, self.algo.max_path_length)
            if pause_for_plot:
                input('Plotting evaluation run: Press Enter to continue...')

    def train(self,
              n_epochs,
              batch_size,
              n_epoch_cycles=1,
              plot=False,
              store_paths=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs(int): Number of epochs.
            batch_size(int): Number of environment steps in one batch.
            n_epoch_cycles(int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            plot(bool): Visualize policy by doing rollout after each epoch.
            store_paths(bool): Save paths in snapshot.
            pause_for_plot(bool): Pause for plot.

        Returns:
            The average return in last epoch cycle.

        """
        if not self.has_setup:
            raise Exception('Use setup() to setup runner before training.')

        if hasattr(self.algo, 'n_epoch_cycles'):
            n_epoch_cycles = self.algo.n_epoch_cycles

        # Save arguments for restore
        self.train_args = types.SimpleNamespace(n_epochs=n_epochs,
                                                n_epoch_cycles=n_epoch_cycles,
                                                batch_size=batch_size,
                                                plot=plot,
                                                store_paths=store_paths,
                                                pause_for_plot=pause_for_plot,
                                                start_epoch=0)

        self.plot = plot

        return self.algo.train(self, batch_size)

    def step_epochs(self):
        """Generator for training.

        This function serves as a generator. It is used to separate
        services such as snapshotting and sampler control from the actual
        training loop. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()

    def resume(self,
               n_epochs=None,
               batch_size=None,
               n_epoch_cycles=None,
               plot=None,
               store_paths=None,
               pause_for_plot=None):
        """Resume from restored experiment.

        This method provides the same interface as train().

        If an argument is not specified, it defaults to the value
        saved from the last call to train().

        Returns:
            The average return in last epoch cycle.

        """
        if self.train_args is None:
            raise Exception('You must call restore() before resume().')

        self.train_args.n_epochs = n_epochs or self.train_args.n_epochs
        self.train_args.batch_size = batch_size or self.train_args.batch_size
        self.train_args.n_epoch_cycles = (n_epoch_cycles
                                          or self.train_args.n_epoch_cycles)

        if plot is not None:
            self.train_args.plot = plot
        if store_paths is not None:
            self.train_args.store_paths = store_paths
        if pause_for_plot is not None:
            self.train_args.pause_for_plot = pause_for_plot

        return self.algo.train(self, self.train_args.batch_size)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
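Because __enter__ and __exit__ are defined, the runner can also be used as a context manager. A minimal sketch, mirroring the resume example in the class docstring (the snapshot directory is a placeholder):

with LocalRunner() as runner:
    runner.restore('data/local/experiment')
    runner.resume(n_epochs=20)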
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--render',
                        action='store_true',
                        help='Render trajectories')
    parser.add_argument('--random', action='store_true', help='Noisy actions')
    parser.add_argument('--num_samples',
                        type=int,
                        default=int(1e5),
                        help='Num samples to collect')
    args = parser.parse_args()

    buffer_data = reset_data()
    snapshotter = Snapshotter()

    with tf.compat.v1.Session():  # optional, only for TensorFlow
        data = snapshotter.load(MODEL_PATH)
        policy = data['algo'].policy
        env = data['env']

        steps, max_steps = 0, 1000
        if args.render:
            env.render()
        obs = env.reset()  # The initial observation
        policy.reset()
        done = False
        ts = 0
        rews = []

        for _ in range(args.num_samples):
            obs = env.reset()  # The initial observation
            policy.reset()
            done = False
            rew = 0.0
            ts = 0
            tot_rew = 0
            # if _ % 1000 == 0:
            print('episode: ', _)

            for _ in range(max_steps):
                if args.render:
                    env.render()

                act, prob = policy.get_action(obs)

                # act[0] is the actual action, while the second tuple is the done variable. Inspiration:
                # https://github.com/lcswillems/rl-starter-files/blob/3c7289765883ca681e586b51acf99df1351f8ead/utils/agent.py#L47

                append_data(buffer_data, obs, act, prob, done, rew)
                new_obs, rew, done, _ = env.step(act)  # why [0] ?
                ts += 1
                tot_rew += rew

                if done:
                    # reset target here!
                    random_act = env.action_space.sample()
                    infos = {
                        'mean': np.random.rand(env.action_space.shape[0]),
                        'log_std': np.random.rand(env.action_space.shape[0])
                    }  # random action info
                    append_data(buffer_data, new_obs, random_act, infos, done,
                                rew)
                    break

                else:
                    # continue by setting current obs
                    obs = new_obs

            rews.append(tot_rew)

        print('Avg Rew: ', np.mean(rews))
        fname = 'generated_hopper_probs.hdf5'
        dataset = h5py.File(fname, 'w')
        npify(buffer_data)
        for key in buffer_data:
            dataset.create_dataset(key,
                                   data=buffer_data[key],
                                   compression='gzip')

        env.close()