Example #1
    def train(self, runner):
        """Start training.

        Parameters
        ----------
        runner : :py:class:`garage.experiment.LocalRunner <garage:garage.experiment.LocalRunner>`
            ``LocalRunner`` is passed to give the algorithm access to
            ``runner.step_epochs()``, which provides services such as
            snapshotting and sampler control.
        """
        self.initial()

        for itr in runner.step_epochs():
            all_paths = {}
            for p in range(self.pop_size):
                with logger.prefix('idv #%d | ' % p):
                    logger.log("Updating Params")
                    self.set_params(itr, p)
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr, runner)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)

                    all_paths[p] = samples_data

            logger.log("Optimizing Population...")
            self.optimize_policy(itr, all_paths)
            self.step_size = self.step_size * self.step_size_anneal
            self.record_tabular(itr)
            runner.step_itr += 1
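A minimal sketch of how a LocalRunner drives this train() method under the standard garage setup flow; `MyAlgo`, `env`, and `snapshot_config` are illustrative stand-ins, not part of the example above:

    from garage.experiment import LocalRunner

    # Illustrative wiring only: LocalRunner.train() ends up calling
    # algo.train(runner), and the loop above then drives step_epochs().
    runner = LocalRunner(snapshot_config)        # snapshot_config: assumed given
    runner.setup(algo=MyAlgo(...), env=env)      # MyAlgo/env are placeholders
    runner.train(n_epochs=100, batch_size=4000)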
Example #2
    def _train(self,
               n_epochs,
               n_epoch_cycles,
               batch_size,
               plot,
               store_paths,
               pause_for_plot,
               start_epoch=0):
        """Start actual training.

        Args:
            n_epochs (int): Number of epochs.
            n_epoch_cycles (int): Number of batches of samples in each epoch.
                This is only useful for off-policy algorithms.
                For on-policy algorithms this value should always be 1.
            batch_size (int): Number of steps in each batch.
            plot (bool): Visualize the policy by doing a rollout after each
                epoch.
            store_paths (bool): Save paths in the snapshot.
            pause_for_plot (bool): Pause for plot.
            start_epoch (int): (internal) The starting epoch.
                Used for experiment resuming.

        Returns:
            float: The average return in the last epoch cycle.

        """
        assert self.has_setup, ('Use Runner.setup() to set up the runner '
                                'before training.')

        # Save arguments for restore
        self.train_args = SimpleNamespace(n_epochs=n_epochs,
                                          n_epoch_cycles=n_epoch_cycles,
                                          batch_size=batch_size,
                                          plot=plot,
                                          store_paths=store_paths,
                                          pause_for_plot=pause_for_plot,
                                          start_epoch=start_epoch)

        self.start_worker()

        self.start_time = time.time()
        itr = start_epoch * n_epoch_cycles

        last_return = None
        for epoch in range(start_epoch, n_epochs):
            self.itr_start_time = time.time()
            paths = None
            with logger.prefix('epoch #%d | ' % epoch):
                for cycle in range(n_epoch_cycles):
                    paths = self.obtain_samples(itr, batch_size)
                    last_return = self.algo.train_once(itr, paths)
                    itr += 1
                self.save(epoch, paths if store_paths else None)
                self.log_diagnostics(pause_for_plot)
                logger.dump_all(itr)
                tabular.clear()

        self.shutdown_worker()

        return last_return
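Since _train() stashes its arguments in a SimpleNamespace, a restore path can simply replay them. A hedged sketch of what that looks like; resume() and its saved_epoch argument are illustrative, not garage API:

    def resume(self, saved_epoch):
        """Illustrative only: re-enter _train() with the stored arguments."""
        args = self.train_args  # the SimpleNamespace saved by _train() above
        return self._train(n_epochs=args.n_epochs,
                           n_epoch_cycles=args.n_epoch_cycles,
                           batch_size=args.batch_size,
                           plot=args.plot,
                           store_paths=args.store_paths,
                           pause_for_plot=args.pause_for_plot,
                           start_epoch=saved_epoch)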
Example #3
    def train(self):
        """Run the sampling/optimization loop, snapshotting every iteration."""
        plotter = Plotter()
        if self.plot:
            plotter.init_plot(self.env, self.policy)
        self.start_worker()
        self.init_opt()
        for itr in range(self.current_itr, self.n_itr):
            with logger.prefix('itr #{} | '.format(itr)):
                paths = self.sampler.obtain_samples(itr)
                samples_data = self.sampler.process_samples(itr, paths)
                self.log_diagnostics(paths)
                self.optimize_policy(itr, samples_data)
                logger.log('Saving snapshot...')
                params = self.get_itr_snapshot(itr, samples_data)
                self.current_itr = itr + 1
                params['algo'] = self
                if self.store_paths:
                    params['paths'] = samples_data['paths']
                snapshotter.save_itr_params(itr, params)
                logger.log('Saved')
                logger.log(tabular)
                if self.plot:
                    plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input('Plotting evaluation run: Press Enter to '
                              'continue...')

        plotter.close()
        self.shutdown_worker()
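The snapshot saved above is a plain parameter dict. A sketch of reading one back, assuming joblib-pickled itr_N.pkl files as in older garage/rllab snapshotters; the path is illustrative:

    import joblib  # assumption: snapshots are joblib pickles

    snapshot = joblib.load('data/local/experiment/itr_10.pkl')  # illustrative
    algo = snapshot['algo']        # stored above via params['algo'] = self
    paths = snapshot.get('paths')  # present only if store_paths was enabled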
Example #4
    def train(self, sess=None):
        """Run training, reporting lifecycle events over a local connection."""
        address = ('localhost', 6000)
        conn = Client(address)
        last_average_return = None
        try:
            created_session = sess is None
            if sess is None:
                sess = tf.compat.v1.Session()
                sess.__enter__()

            sess.run(tf.compat.v1.global_variables_initializer())
            conn.send(ExpLifecycle.START)
            self.start_worker(sess)
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log('Obtaining samples...')
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.obtain_samples(itr)
                    logger.log('Processing samples...')
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.process_samples(itr, paths)
                    last_average_return = samples_data['average_return']
                    logger.log('Logging diagnostics...')
                    self.log_diagnostics(paths)
                    logger.log('Optimizing policy...')
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('Saving snapshot...')
                    params = self.get_itr_snapshot(itr)
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('Saved')
                    tabular.record('Time', time.time() - start_time)
                    tabular.record('ItrTime', time.time() - itr_start_time)
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        self.plotter.update_plot(self.policy,
                                                 self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')

            conn.send(ExpLifecycle.SHUTDOWN)
            self.shutdown_worker()
            if created_session:
                sess.close()
        finally:
            conn.close()
        return last_average_return
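Client(('localhost', 6000)) implies a server already listening on that port. A minimal stand-in for that receiver, using only the standard library; the event handling is reduced to a print:

    from multiprocessing.connection import Listener

    # Toy receiver for the ExpLifecycle messages sent by train() above.
    with Listener(('localhost', 6000)) as listener:
        with listener.accept() as conn:
            while True:
                try:
                    msg = conn.recv()   # e.g. ExpLifecycle.OBTAIN_SAMPLES
                except EOFError:        # peer called conn.close()
                    break
                print('lifecycle event:', msg)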
Example #5
    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        self._start_worker()
        self._start_time = time.time()
        self.step_itr = self._stats.total_itr
        self.step_path = None

        # Used by integration tests to ensure examples can run one epoch.
        n_epochs = int(
            os.environ.get('GARAGE_EXAMPLE_TEST_N_EPOCHS',
                           self._train_args.n_epochs))

        logger.log('Obtaining samples...')

        for epoch in range(self._train_args.start_epoch, n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix('epoch #%d | ' % epoch):
                yield epoch
                save_path = (self.step_path
                             if self._train_args.store_paths else None)

                self._stats.last_path = save_path
                self._stats.total_epoch = epoch
                self._stats.total_itr = self.step_itr

                self.save(epoch)
                self.log_diagnostics(self._train_args.pause_for_plot)
                logger.dump_all(self.step_itr)
                tabular.clear()
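Because n_epochs is re-read from the environment, integration tests can cap any experiment at one epoch without touching its code. A minimal sketch:

    import os

    # Set before the experiment launches; step_epochs() above then ignores
    # the configured n_epochs and runs a single epoch (when start_epoch is 0).
    os.environ['GARAGE_EXAMPLE_TEST_N_EPOCHS'] = '1'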
Example #6
    def train(self, num_iter, dump=False):
        """Run `num_iter` calls to train_step(), optionally dumping logs."""
        for i in range(num_iter):
            with logger.prefix(' | Iteration {} |'.format(i)):
                t1 = time.time()
                self.train_step()
                t2 = time.time()
                logger.log('Step time: {:.2f} s'.format(t2 - t1))
                if dump:
                    logger.log(tabular)
                    logger.dump_all(i)
                    tabular.clear()
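The manual t1/t2 bookkeeping above generalizes to a small context manager. A sketch using the same dowel-style logger; timed() is illustrative, not a garage utility:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label):
        """Log how long the wrapped block took, like the t1/t2 pair above."""
        t0 = time.time()
        try:
            yield
        finally:
            logger.log('{} took {:.2f} s'.format(label, time.time() - t0))

    # Usage inside train():
    #     with timed('train_step'):
    #         self.train_step()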
Example #7
    def step_epochs(self):
        """Step through each epoch.

        This function returns a magic generator. When iterated through, this
        generator automatically performs services such as snapshotting and log
        management. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    print("save_path:", save_path)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()
Example #8
    def step_epochs(self):
        """Generator for training.

        This function serves as a generator. It is used to separate
        services such as snapshotting and sampler control from the actual
        training loop. It is used inside train() in each algorithm.

        The generator initializes two variables: `self.step_itr` and
        `self.step_path`. To use the generator, these two have to be
        updated manually in each epoch, as the example shows below.

        Yields:
            int: The next training epoch.

        Examples:
            for epoch in runner.step_epochs():
                runner.step_path = runner.obtain_samples(...)
                self.train_once(...)
                runner.step_itr += 1

        """
        try:
            self._start_worker()
            self._start_time = time.time()
            self.step_itr = (self.train_args.start_epoch *
                             self.train_args.n_epoch_cycles)
            self.step_path = None

            for epoch in range(self.train_args.start_epoch,
                               self.train_args.n_epochs):
                self._itr_start_time = time.time()
                with logger.prefix('epoch #%d | ' % epoch):
                    yield epoch
                    save_path = (self.step_path
                                 if self.train_args.store_paths else None)
                    self.save(epoch, save_path)
                    self.log_diagnostics(self.train_args.pause_for_plot)
                    logger.dump_all(self.step_itr)
                    tabular.clear()
        finally:
            self._shutdown_worker()
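The try/finally around the yield is what guarantees _shutdown_worker() runs even when an algorithm abandons the loop early. A toy model of that behavior; the names are illustrative:

    def toy_epochs(n):
        """Mimics step_epochs(): cleanup lives in the generator's finally."""
        try:
            print('start worker')
            for epoch in range(n):
                yield epoch
        finally:
            print('shutdown worker')

    for epoch in toy_epochs(10):
        if epoch == 2:
            break
    # Closing the generator (here via CPython's prompt refcount collection,
    # or an explicit gen.close()) raises GeneratorExit at the paused yield,
    # so 'shutdown worker' prints even though the loop exited early.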
Example #9
    def train(self):
        """Run training, mirroring each phase to an external lifecycle listener."""
        address = ('localhost', 6000)
        conn = Client(address)
        try:
            plotter = Plotter()
            if self.plot:
                plotter.init_plot(self.env, self.policy)
            conn.send(ExpLifecycle.START)
            self.start_worker()
            self.init_opt()
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #{} | '.format(itr)):
                    conn.send(ExpLifecycle.OBTAIN_SAMPLES)
                    paths = self.sampler.obtain_samples(itr)
                    conn.send(ExpLifecycle.PROCESS_SAMPLES)
                    samples_data = self.sampler.process_samples(itr, paths)
                    self.log_diagnostics(paths)
                    conn.send(ExpLifecycle.OPTIMIZE_POLICY)
                    self.optimize_policy(itr, samples_data)
                    logger.log('Saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    self.current_itr = itr + 1
                    params['algo'] = self
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    snapshotter.save_itr_params(itr, params)
                    logger.log('Saved')
                    logger.log(tabular)
                    if self.plot:
                        conn.send(ExpLifecycle.UPDATE_PLOT)
                        plotter.update_plot(self.policy, self.max_path_length)
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')

            conn.send(ExpLifecycle.SHUTDOWN)
            plotter.close()
            self.shutdown_worker()
        finally:
            conn.close()
Example #10
    def train(self,
              n_epochs,
              batch_size=None,
              plot=False,
              store_episodes=False,
              pause_for_plot=False):
        """Start training.

        Args:
            n_epochs (int): Number of epochs.
            batch_size (int or None): Number of environment steps in one batch.
            plot (bool): Visualize an episode from the policy after each epoch.
            store_episodes (bool): Save episodes in snapshot.
            pause_for_plot (bool): Pause for plot.

        Raises:
            NotSetupError: If train() is called before setup().

        Returns:
            float: The average return in the last epoch cycle.

        """
        self.batch_size = batch_size
        self.store_episodes = store_episodes
        self.pause_for_plot = pause_for_plot
        if not self._has_setup:
            raise NotSetupError(
                'Use setup() to set up the trainer before training.')

        self._plot = plot

        last_average_return = None
        for itr in range(self.start_itr, n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix(f'itr #{itr} | '):

                # train policy
                self._algo.train(self)

                # compute irl and update reward function
                logger.log('Obtaining paths...')
                paths = self.obtain_samples(itr)
                logger.log('Processing paths...')
                paths = self._train_irl(paths, itr=itr)
                samples_data = self.process_samples(itr, paths)
                # 'average_return' mirrors the pattern in Example #4 above.
                last_average_return = samples_data['average_return']

                logger.log('Logging diagnostics...')
                logger.log('Time %.2f s' % (time.time() - self._start_time))
                logger.log('EpochTime %.2f s' %
                           (time.time() - self._itr_start_time))
                tabular.record('TotalEnvSteps', self._stats.total_env_steps)
                self.log_diagnostics(paths)

                logger.log('Saving snapshot...')
                self.save(itr, paths=paths)
                logger.log('Saved')
                tabular.record('Time', time.time() - self._start_time)
                tabular.record('ItrTime', time.time() - self._itr_start_time)
                logger.dump_all(itr)
                tabular.clear()

        self._shutdown_worker()

        return last_average_return
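A sketch of the call side, assuming this method lives on a garage-style Trainer; `trainer`, `algo`, and `env` are placeholders:

    # Illustrative only: setup() is assumed to follow garage's Trainer
    # convention of setting the _has_setup flag checked above.
    trainer.setup(algo=algo, env=env)
    avg_return = trainer.train(n_epochs=100,
                               batch_size=4000,
                               store_episodes=True)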