Example #1
    def iteration(self, sample_lists, itr):
        """Run one iteration: store samples, evaluate their costs, and update the policy."""
        # Store the samples and evaluate the costs.
        for m in range(self.M):
            self.cur[m].sample_list = sample_lists[m]
            self._eval_cost(m)

        with Timer(self.timers, 'pol_update'):
            self._update_policy()

        # Prepare for next iteration
        self._advance_iteration_variables()
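
The Timer context manager used throughout these examples is not shown in the snippets; as a rough sketch (names and behavior assumed, not taken from the source), it could be a small context manager that accumulates elapsed wall-clock seconds into the timers dict under the given key:

import time

class Timer:
    """Hypothetical sketch: accumulate elapsed seconds into timers[key]."""

    def __init__(self, timers, key):
        self.timers = timers
        self.key = key

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Add this block's wall-clock duration to any previously recorded time.
        self.timers[self.key] = self.timers.get(self.key, 0.0) + (time.time() - self._start)

timers = {}
with Timer(timers, 'pol_update'):
    pass  # the policy update would run here
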
Example #2
    def _update_trajectories(self):
        """Compute new linear Gaussian controllers."""
        if not hasattr(self, 'new_traj_distr'):
            self.new_traj_distr = [
                self.cur[cond].traj_distr for cond in range(self.M)
            ]
        with Timer(self.timers, 'traj_opt'):
            for cond in range(self.M):
                (self.new_traj_distr[cond], self.cur[cond].eta,
                 self.new_mu[cond], self.new_sigma[cond]) = self.traj_opt_update(cond)

        self.visualize_local_policy(0)
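
traj_opt_update itself is not shown above; judging from the unpacking in the loop, it returns a 4-tuple per condition: the new trajectory distribution, the updated dual variable eta, and the new state means and covariances. A hypothetical stub illustrating that return contract (all names and shapes here are assumptions, not from the source):

import numpy as np

def traj_opt_update_stub(cond, T=100, dX=26):
    """Hypothetical stand-in matching the 4-tuple unpacked in Example #2."""
    traj_distr = None                 # placeholder for the new linear-Gaussian controller
    eta = 1.0                         # updated dual variable for the KL-divergence constraint
    mu = np.zeros((T, dX))            # new mean state trajectory (shape assumed)
    sigma = np.zeros((T, dX, dX))     # new state covariances (shape assumed)
    return traj_distr, eta, mu, sigma
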
Example #3
    def iteration(self, sample_lists, _):
        """
        Run one iteration of MDGPS-based guided policy search.

        Args:
            sample_lists: List of SampleList objects for each condition.
            _: Unused; present to match the parent class signature.
        """
        # Store the samples and evaluate the costs.
        for m in range(self.M):
            self.cur[m].sample_list = sample_lists[m]
            self._eval_cost(m)

        # Update dynamics linearizations.
        self._update_dynamics()

        # On the first iteration, the policy needs to catch up to init_traj_distr.
        if self.iteration_count == 0:
            self.new_traj_distr = [
                self.cur[cond].traj_distr for cond in range(self.M)
            ]
            self._update_policy()

        # Update policy linearizations.
        with Timer(self.algorithm.timers, 'pol_lin'):
            for m in range(self.M):
                self._update_policy_fit(m)

        # C-step
        if self.iteration_count > 0:
            self._stepadjust()
        self._update_trajectories()

        # S-step
        with Timer(self.algorithm.timers, 'pol_update'):
            self._update_policy()

        # Prepare for next iteration
        self._advance_iteration_variables()
Example #4
    def run(self):
        """Runs training by alternatively taking samples and optimizing the policy."""
        if 'load_model' in self._hyperparams:
            self.iteration_count = self._hyperparams['load_model'][1]
            self.algorithm.policy_opt.iteration_count = self.iteration_count
            self.algorithm.policy_opt.restore_model(
                *self._hyperparams['load_model'])

            # Global policy static resets
            if self._hyperparams['num_pol_samples_static'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_pol_samples_static'],
                    pol=self.algorithm.policy_opt.policy,
                    rnd=False),
                                    '_pol-static',
                                    visualize=True)

            return

        for itr in range(self._hyperparams['iterations']):
            self.iteration_count = itr
            if hasattr(self.algorithm, 'traj_opt'):
                self.algorithm.traj_opt.iteration_count = itr
            if hasattr(self.algorithm, 'policy_opt'):
                self.algorithm.policy_opt.iteration_count = itr

            print("*** Iteration %02d ***" % itr)
            if itr == 0 and 'load_initial_samples' in self._hyperparams:
                # Load trajectory samples
                print('Loading initial samples ...')
                sample_files = self._hyperparams['load_initial_samples']
                traj_sample_lists = [[] for _ in range(self.algorithm.M)]
                for sample_file in sample_files:
                    data = np.load(sample_file)
                    X, U = data['X'], data['U']
                    assert X.shape[0] == self.algorithm.M
                    for m in range(self.algorithm.M):
                        for n in range(X.shape[1]):
                            traj_sample_lists[m].append(
                                self.agent.pack_sample(X[m, n], U[m, n]))
                traj_sample_lists = [
                    SampleList(traj_samples)
                    for traj_samples in traj_sample_lists
                ]
            else:
                # Take trajectory samples
                with Timer(self.algorithm.timers, 'sampling'):
                    for cond in self._train_idx:
                        for i in trange(self._hyperparams['num_samples'],
                                        desc='Taking samples'):
                            self._take_sample(cond, i)
                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]
            self.export_samples(traj_sample_lists, visualize=True)

            # Iteration
            with Timer(self.algorithm.timers, 'iteration'):
                self.algorithm.iteration(traj_sample_lists, itr)
            self.export_dynamics()
            self.export_controllers()
            self.export_times()
            if hasattr(self.algorithm, 'policy_opt') and hasattr(
                    self.algorithm.policy_opt, 'store_model'):
                self.algorithm.policy_opt.store_model()

            # Sample learned policies for visualization

            # LQR policies static resets
            if self._hyperparams['num_lqr_samples_static'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_lqr_samples_static'],
                    pol=None,
                    rnd=False),
                                    '_lqr-static',
                                    visualize=True)

            # LQR policies random resets
            if self._hyperparams['num_lqr_samples_random'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_lqr_samples_random'],
                    pol=None,
                    rnd=True),
                                    '_lqr-random',
                                    visualize=True)

            # LQR policies static resets with randomized initial state
            if self._hyperparams['num_lqr_samples_random'] > 0:
                self.export_samples(self._take_policy_samples(
                    N=self._hyperparams['num_lqr_samples_random'],
                    pol=None,
                    rnd=False,
                    randomize_initial_state=24),
                                    '_lqr-static-randomized',
                                    visualize=True)

            if hasattr(self.algorithm, 'policy_opt'):
                # Global policy static resets
                if self._hyperparams['num_pol_samples_static'] > 0:
                    self.export_samples(self._take_policy_samples(
                        N=self._hyperparams['num_pol_samples_static'],
                        pol=self.algorithm.policy_opt.policy,
                        rnd=False),
                                        '_pol-static',
                                        visualize=True)

                # Global policy random resets
                if self._hyperparams['num_pol_samples_random'] > 0:
                    self.export_samples(self._take_policy_samples(
                        N=self._hyperparams['num_pol_samples_random'],
                        pol=self.algorithm.policy_opt.policy,
                        rnd=True),
                                        '_pol-random',
                                        visualize=True)

                # Global policy static resets with randomized initial state
                if self._hyperparams['num_pol_samples_random'] > 0:
                    self.export_samples(self._take_policy_samples(
                        N=self._hyperparams['num_pol_samples_random'],
                        pol=self.algorithm.policy_opt.policy,
                        rnd=False,
                        randomize_initial_state=24),
                                        '_pol-static-randomized',
                                        visualize=True)

            self.visualize_training_progress()
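
The run loop in Example #4 reads its configuration from self._hyperparams; a hypothetical dict covering only the keys referenced above (all values are placeholders, not defaults from the source):

hyperparams = {
    'iterations': 10,                  # number of outer GPS iterations
    'num_samples': 5,                  # trajectory samples per condition per iteration
    'num_lqr_samples_static': 5,       # LQR policy samples from static resets (0 disables)
    'num_lqr_samples_random': 5,       # LQR policy samples from random resets (0 disables)
    'num_pol_samples_static': 5,       # global policy samples from static resets (0 disables)
    'num_pol_samples_random': 5,       # global policy samples from random resets (0 disables)
    # Optional: resume from a stored model instead of training; the arguments are passed
    # through to policy_opt.restore_model, and element [1] is used as the iteration count.
    # 'load_model': (...),
    # Optional: seed iteration 0 from stored samples; each .npz must contain arrays X and U
    # with the condition index as the leading dimension.
    # 'load_initial_samples': ['samples.npz'],
}
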
Example #5
    def run(self, session_id, itr_load=None):
        """
        Run training by alternately taking samples and performing algorithm iterations.

        Args:
            session_id: Identifier for this training session; stored if not None.
            itr_load: If specified, loads algorithm state from that
                iteration and resumes training at the next iteration.

        Returns: None
        """
        itr_start = self._initialize(itr_load)
        if session_id is not None:
            self.session_id = session_id

        for itr in range(itr_start, self._hyperparams['iterations']):
            self.iteration_count = itr
            if hasattr(self.algorithm, 'policy_opt'):
                self.algorithm.policy_opt.iteration_count = itr

            print("*** Iteration %02d ***" % itr)
            # Take trajectory samples
            with Timer(self.algorithm.timers, 'sampling'):
                for cond in self._train_idx:
                    for i in trange(self._hyperparams['num_samples'],
                                    desc='Taking samples'):
                        self._take_sample(itr, cond, i)
            traj_sample_lists = [
                self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]
            self.export_samples(traj_sample_lists)

            # Iteration
            with Timer(self.algorithm.timers, 'iteration'):
                self.algorithm.iteration(traj_sample_lists, itr)
            self.export_dynamics()
            self.export_controllers()
            self.export_times()

            # Sample learned policies for visualization

            # LQR policies static resets
            if self._hyperparams['num_lqr_samples_static'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_lqr_samples_static'],
                        pol=None,
                        rnd=False), '_lqr-static')

            # LQR policies random resets
            if self._hyperparams['num_lqr_samples_random'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_lqr_samples_random'],
                        pol=None,
                        rnd=True), '_lqr-random')

            if hasattr(self.algorithm, 'policy_opt'):
                # Global policy static resets
                if self._hyperparams['num_pol_samples_static'] > 0:
                    self.export_samples(
                        self._take_policy_samples(
                            N=self._hyperparams['num_pol_samples_static'],
                            pol=self.algorithm.policy_opt.policy,
                            rnd=False), '_pol-static')

                # Global policy random resets
                if self._hyperparams['num_pol_samples_random'] > 0:
                    self.export_samples(
                        self._take_policy_samples(
                            N=self._hyperparams['num_pol_samples_random'],
                            pol=self.algorithm.policy_opt.policy,
                            rnd=True), '_pol-random')

        self._end()
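
Both run loops retrieve the most recent samples with self.agent.get_samples(cond, -num_samples). The agent's real implementation is not shown in these examples; a hypothetical sketch of that convention, where a negative start index slices the tail of a per-condition sample list:

class AgentSamplesSketch:
    """Hypothetical sketch of the per-condition sample store assumed above."""

    def __init__(self):
        self._samples = {}                     # condition index -> list of samples

    def add_sample(self, cond, sample):
        self._samples.setdefault(cond, []).append(sample)

    def get_samples(self, cond, start=0, end=None):
        # A negative start returns the last abs(start) samples for this condition.
        return self._samples.get(cond, [])[start:end]

With this convention, get_samples(cond, -5) returns the five most recently taken samples for that condition, matching how the loops above collect only the samples from the current iteration.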