def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """

        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']

        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])

        self.data_logger = DataLogger()

        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
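
A minimal sketch of how these constructors are typically driven. It assumes the experiments/<experiment_name>/hyperparams.py layout used by cbfinn/gps, where the module exposes a `config` dict; the module path, experiment name, and file path below are illustrative, not taken from this page.

# Hypothetical driver script (illustrative paths; GPSMain import path assumed from the cbfinn/gps layout).
import imp

from gps.gps_main import GPSMain

hyperparams = imp.load_source(
    'hyperparams', 'experiments/my_experiment/hyperparams.py')
gps = GPSMain(hyperparams.config, quit_on_end=False)
gps.run(itr_load=None)  # start training from scratch
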
Example #2
    def __init__(self, config, quit_on_end=False):
        """
		Initialize GPSMain
		Args:
			config: Hyperparameters for experiment
			quit_on_end: When true, quit automatically on completion
		"""
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        # print(config)
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']
        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        # Hard-coded to pass the agent's target_state and map_size into the cost config.
        config['algorithm']['cost']['costs'][1]['data_types'][3][
            'target_state'] = config['agent']['target_state']
        config['algorithm']['cost']['costs'][1]['data_types'][3][
            'map_size'] = config['agent']['map_size']
        # config['algorithm']['cost']['costs'][1]['data_types'][3]['map_size'] = CUT_MAP_SIZE

        if len(config['algorithm']['cost']['costs']) > 2:
            # Temporarily deprecated: the collision cost is not currently considered.
            # This branch handles configs that also include cost_collision.
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'target_state'] = config['agent']['target_state']
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'map_size'] = config['agent']['map_size']
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'map_state'] = config['agent']['map_state']
        # print(config['algorithm'])
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        # Modified by RH
        self.finishing_time = None
        self.U = None
        self.final_pos = None
        self.samples = []
        self.quick_sample = None
        # self.map_size = config['agent']['map_size']
        self.map_size = CUT_MAP_SIZE
        self.display_center = config['agent']['display_center']
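
The hard-coded indexing above implies a particular nesting of the cost configuration. The fragment below is a self-contained illustration (plain dicts, no GPS imports) of that nested update; the key 3 stands in for a sensor-type id and all values are placeholders, not values from the project.

# Assumed shape of the cost config that Example #2 patches.
config = {
    'agent': {'target_state': [0.0, 0.0], 'map_size': 32},
    'algorithm': {'cost': {'costs': [
        {'data_types': {}},        # costs[0], e.g. an action cost
        {'data_types': {3: {}}},   # costs[1], patched below
    ]}},
}
cost1 = config['algorithm']['cost']['costs'][1]['data_types'][3]
cost1['target_state'] = config['agent']['target_state']
cost1['map_size'] = config['agent']['map_size']
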
Example #3
    def __init__(self, config):
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
Example #4
File: gps_main.py  Project: cbfinn/gps
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common']['conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
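
The fields read by this constructor suggest the minimal config layout sketched below. The placeholder classes and values are assumptions for illustration; a real experiment references agent and algorithm classes from the gps package.

# Minimal sketch of the config dict consumed by GPSMain.__init__
# (placeholder classes stand in for real agent/algorithm types).
class DummyAgent(object):
    def __init__(self, hyperparams): self._hyperparams = hyperparams

class DummyAlgorithm(object):
    def __init__(self, hyperparams): self._hyperparams = hyperparams

config = {
    'gui_on': False,
    'common': {
        'conditions': 4,   # number of initial conditions
        'data_files_dir': 'experiments/my_experiment/data_files/',
        # 'train_conditions' / 'test_conditions' are optional; when absent,
        # all conditions are used for both training and testing.
    },
    'agent': {'type': DummyAgent},
    'algorithm': {'type': DummyAlgorithm},
    'iterations': 10,        # read later by run()
    'num_samples': 5,
    'verbose_trials': 1,
}
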
Example #5
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
        self.algorithm.init_samples = self._hyperparams['num_samples']

        if self.algorithm._hyperparams['ioc']:
            demo_file = self._data_files_dir + 'demos.pkl'
            demos = self.data_logger.unpickle(demo_file)
            if demos is None:
                self.demo_gen = GenDemo(config)
                self.demo_gen.ioc_algo = self.algorithm
                self.demo_gen.generate()
                demo_file = self._data_files_dir + 'demos.pkl'
                demos = self.data_logger.unpickle(demo_file)
            config['agent']['pos_body_offset'] = demos['pos_body_offset']
            self.agent = config['agent']['type'](config['agent'])
            self.algorithm.demoX = demos['demoX']
            self.algorithm.demoU = demos['demoU']
            self.algorithm.demoO = demos['demoO']
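
Example #5 loads a demos.pkl produced by GenDemo. From the keys accessed above, the pickle holds at least the fields sketched below; the array shapes are assumptions (N demos, T timesteps, and state/action/observation dimensions dX/dU/dO), written here only to make the structure concrete.

# Hypothetical structure of demos.pkl, inferred from the accesses above.
import numpy as np

N, T, dX, dU, dO = 5, 100, 26, 7, 26   # illustrative sizes only
demos = {
    'demoX': np.zeros((N, T, dX)),     # demonstration states
    'demoU': np.zeros((N, T, dU)),     # demonstration actions
    'demoO': np.zeros((N, T, dO)),     # demonstration observations
    'pos_body_offset': [np.zeros(3) for _ in range(N)],  # per-demo offsets
}
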
Example #6
    def __init__(self, config):
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
Example #7
File: gps_main.py  Project: rsdk/inpuls
    def __init__(self, config, quit_on_end=False, no_algorithm=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']
        config['agent']['data_files_dir'] = self._data_files_dir
        config['algorithm']['data_files_dir'] = self._data_files_dir

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = None
        if config['gui_on']:
            from gps.gui.gps_training_gui import GPSTrainingGUI  # Only import if necessary
            self.gui = GPSTrainingGUI(config['common'])
        self.mode = None

        config['algorithm']['agent'] = self.agent
        if not no_algorithm:
            self.algorithm = config['algorithm']['type'](config['algorithm'])
            self.algorithm._data_files_dir = self._data_files_dir
            if hasattr(self.algorithm, 'policy_opt'):
                self.algorithm.policy_opt._data_files_dir = self._data_files_dir

        self.session_id = None
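
A hedged usage sketch for this variant: passing no_algorithm=True skips constructing the algorithm (useful when only replaying or inspecting data), so self.algorithm is never set and must not be used.

# Usage sketch (assumes GPSMain is imported and 'config' is a valid inpuls experiment config).
gps = GPSMain(config, quit_on_end=False, no_algorithm=True)
assert not hasattr(gps, 'algorithm')   # the algorithm is intentionally not constructed
print(gps._data_files_dir)             # the data-files directory is still wired up
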
Example #8
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        # self._condition = 1
        if 'train_conditions' in config['common']:
            # (this condition evaluated to False in the default experiment)
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            # Record the condition count under the new 'train_conditions' key.
            self._hyperparams = config
            # Re-assign the hyperparameters because config was modified above.
            self._test_idx = self._train_idx
            # Reuse the train indices for testing.
        self._data_files_dir = config['common']['data_files_dir']
        # Data file path stored in the 'common' dict.
        self.agent = config['agent']['type'](config['agent'])
        # Instantiate the agent class specified in the config.
        # print(self.agent, 'self.agent')
        self.data_logger = DataLogger()
        # Optional training GUI.
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None
        # Add the agent object to the algorithm config before constructing it.
        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])
Example #9
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self.start_time = timeit.default_timer()
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = list(range(self._conditions))
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        # CB: save an image of the cost even when running without the GUI.
        if not config['gui_on']:
            self.simplePlotter = SimplePlotter(
                config['common']['experiment_name'],
                config['common']['data_files_dir'])
Example #10
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """

        self.config = config
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None, parallel=False):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        timestamps = {
            "start": time.time(),
            "local_controller_times": [],
            "iteration_times": [],
            "end": None
        }
        try:
            if not parallel:
                itr_start = self._initialize(itr_load)
                for itr in range(itr_start, self._hyperparams['iterations']):
                    iteration_start = time.time()
                    cond_start = time.time()
                    for cond in self._train_idx:
                        for i in range(self._hyperparams['num_samples']):
                            self._take_sample(itr, cond, i)
                    traj_sample_lists = [
                        self.agent.get_samples(
                            cond, -self._hyperparams['num_samples'])
                        for cond in self._train_idx
                    ]

                    cond_end = time.time()
                    timestamps["local_controller_times"].append(cond_end -
                                                                cond_start)
                    print(f"Controller time: {cond_end-cond_start}")
                    # Clear agent samples.
                    self.agent.clear_samples()

                    self._take_iteration(itr, traj_sample_lists)
                    pol_sample_lists = self._take_policy_samples(itr)
                    self._log_data(itr, traj_sample_lists, pol_sample_lists)
                    iteration_end = time.time()
                    timestamps["iteration_times"].append(iteration_end -
                                                         iteration_start)
                    print(f'Iteration time: {iteration_end-iteration_start}')
            else:
                itr_start = self._initialize(itr_load)
                for itr in range(itr_start, self._hyperparams['iterations']):
                    iteration_start = time.time()
                    jobs = self._train_idx
                    filenames = ["temp{k}.pkl" for k, _ in enumerate(jobs)]
                    for f in filenames:
                        with open(f, 'wb') as outfile:
                            pickle.dump(self.algorithm, outfile)
                    config = dict()
                    config['agent'] = self._hyperparams['agent']
                    config['num_samples'] = self._hyperparams['num_samples']
                    config['verbose_trials'] = self._hyperparams[
                        'verbose_trials']
                    localargs = [(cond, itr, config, filenames[i])
                                 for i, cond in enumerate(jobs)]
                    # run_local_controller(*localargs[0])
                    cond_start = time.time()
                    # Set the number of workers, up to the number of conditions.
                    print(f"Running local rollouts with {parallel} workers!")
                    with multiprocessing.Pool(processes=min(
                            parallel, len(self._train_idx))) as pool:
                        results = pool.starmap(run_local_controller, localargs)
                    cond_end = time.time()
                    timestamps["local_controller_times"].append(cond_end -
                                                                cond_start)
                    # traj_sample_lists = [
                    #     agent.get_samples(0, -self._hyperparams['num_samples'])
                    #     for agent in results
                    # ]
                    traj_sample_lists = results
                    for traj in traj_sample_lists:
                        for sample in traj:
                            sample.agent = self.agent  # attach the default agent so it is not None
                    # Clear agent samples.
                    self.agent.clear_samples()
                    self._take_iteration(itr, traj_sample_lists)
                    pol_sample_lists = self._take_policy_samples(itr)
                    self._log_data(itr, traj_sample_lists, pol_sample_lists)
                    iteration_end = time.time()
                    timestamps["iteration_times"].append(iteration_end -
                                                         iteration_start)
            timestamps["end"] = time.time()
            with open(
                    os.path.join(self._hyperparams['common']['data_files_dir'],
                                 'timestamps.json'), 'w') as outfile:
                json.dump(timestamps, outfile)
            '''
            itr_start = self._initialize(itr_load)

            # import pdb; pdb.set_trace

            for itr in range(itr_start, self._hyperparams['iterations']):
                jobs = self._train_idx
                # for cond in self._train_idx:
                # test = copy.copy(self.algorithm)
                # test = deepcopy(self.agent)
                # algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
                config = self.config
                filenames = [f"temp{k}" for k,_ in enumerate(jobs)]
                
                agents = [config['agent'] for _ in range(len(jobs))]
                deg_obs = config['agent']['sensor_dims'][RGB_IMAGE]
                deg_action = config['agent']['sensor_dims'][ACTION]
                for f in filenames:
                    self.algorithm.policy_opt.policy.pickle_policy(deg_obs, deg_action, f, should_hash=False)
                    # with open(f, 'wb') as tempalgfile:
                    #     pickle.dump(self.algorithm.policy_opt.policy, tempalgfile)
                policy_filenames = [f + '/_pol' for f in filenames]
                
                localargs = [(cond, self._hyperparams['num_samples'], itr, policy_filenames[j], agents[j], self._hyperparams['algorithm']['policy_opt'], self._hyperparams['verbose_trials']) for j,cond in enumerate(jobs)]
                run_local_controller(*localargs[0])
                import pdb; pdb.set_trace()
                with multiprocessing.Pool(processes = min(6,len(self._train_idx))) as pool:
                    results = pool.starmap(run_local_controller, localargs)
                # import pdb; pdb.set_trace()
                self.agent._samples = results
                traj_sample_lists = [
                    self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # Clear agent samples.
                self.agent.clear_samples()

                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)
                '''

        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(itr, N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            self.itr_load = itr_load
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    itr,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol,
                cond,
                itr,
                verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, itr, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                itr,
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
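
In this variant, run() takes a worker count through the `parallel` argument and writes coarse timing data to timestamps.json. The usage sketch below is illustrative; it assumes GPSMain is imported and 'config' is a valid experiment config, and the JSON layout is inferred from the dict assembled in run().

# Usage sketch for the parallel sampling path.
gps = GPSMain(config)
gps.run(parallel=4)      # sample the train conditions with 4 worker processes
# gps.run()              # sequential path (parallel=False)
# Afterwards, timestamps.json in data_files_dir holds the dict built in run():
# {"start": ..., "local_controller_times": [...], "iteration_times": [...], "end": ...}
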
Example #11
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common']['conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        print('GPS_Main Running')
        try:
            itr_start = self._initialize(itr_load)

            for itr in range(itr_start, self._hyperparams['iterations']):
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        print('Taking Sample %i/%i' % (i+1, self._hyperparams['num_samples']))
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # Clear agent samples.
                self.agent.clear_samples()

                # tf_sess = self.algorithm.policy_opt.sess
                # print(tf_sess.graph.get_tensor_by_name('w_1:0').eval(session=tf_sess))
                self._take_iteration(itr, traj_sample_lists)
                # print(tf_sess.graph.get_tensor_by_name('w_1:0').eval(session=tf_sess))

                # Save tensorflow NN weights as csv
                print('Saving NN weights')
                tf_sess = self.algorithm.policy_opt.sess
                policy_opt_params = self.algorithm.policy_opt._hyperparams
                with tf_sess.graph.as_default():
                    for i in range(policy_opt_params['network_params']['n_layers']+1):
                        kernel = tf_sess.graph.get_tensor_by_name('w_'+str(i)+':0').eval(session=tf_sess)
                        bias = tf_sess.graph.get_tensor_by_name('b_'+str(i)+':0').eval(session=tf_sess)
                        # print(kernel.shape)
                        # print(bias.shape)
                        weights = np.vstack((kernel, bias.reshape(1, len(bias))))
                        # print(weights.shape)
                        filename = self._data_files_dir + ('pol_wgts_itr_%d_l_%d.csv' % (itr, i))
                        np.savetxt(filename, weights, delimiter=',')
                if itr == 0:
                    print('Saving normalization statistics')
                    scale = np.diag(self.algorithm.policy_opt.policy.scale)
                    bias = self.algorithm.policy_opt.policy.bias
                    print(scale.shape)
                    print(bias.shape)
                    np.savetxt(self._data_files_dir + 'nn_in_scale.csv', scale, delimiter=',')
                    np.savetxt(self._data_files_dir + 'nn_in_bias.csv', bias, delimiter=',')

                pol_sample_lists = self._take_policy_samples()
                # print('policy sample lists: ',pol_sample_lists)
                self._log_data(itr, traj_sample_lists, pol_sample_lists)

                # Calculate average costs from policy samples and print results
                costs = [np.mean(np.sum(self.algorithm.prev[m].cs, axis=1)) for m in range(self.algorithm.M)]
                self._print_pol_sample_results(itr, self.algorithm, costs, pol_sample_lists)

        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1) # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
            ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists)
        )

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.set_status_text(('Took %d policy sample(s) from ' +
                'algorithm state at iteration %d.\n' +
                'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') % (N, itr, itr))

    def _print_pol_sample_results(self, itr, algorithm, costs, pol_sample_lists):
        if isinstance(algorithm, AlgorithmMDGPS) or isinstance(algorithm, AlgorithmBADMM):
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields  = '%3s | %8s %12s' % ('itr', 'avg_cost', 'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields  = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields  += ' | %8s %8s %8s' % ('  cost  ', '  step  ', 'entropy ')
            if isinstance(algorithm, AlgorithmBADMM):
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields  += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i', 'kl_div_f')
            elif isinstance(algorithm, AlgorithmMDGPS):
                condition_titles += ' %8s' % ('')
                itr_data_fields  += ' %8s' % ('pol_cost')
        print(condition_titles)
        print(itr_data_fields)

        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            test_idx = algorithm._hyperparams['test_conditions']
            # pol_sample_lists is a list of singletons
            samples = [sl[0] for sl in pol_sample_lists]
            pol_costs = [np.sum(algorithm.cost[idx].eval(s)[0])
                    for s, idx in zip(samples, test_idx)]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = np.mean(algorithm.prev[m].step_mult * algorithm.base_kl_step)
            entropy = 2*np.sum(np.log(np.diagonal(algorithm.prev[m].traj_distr.chol_pol_covar,
                    axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if isinstance(algorithm, AlgorithmBADMM):
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i, kl_div_f)
            elif isinstance(algorithm, AlgorithmMDGPS):
                # TODO: Change for test/train better.
                if test_idx == algorithm._hyperparams['train_conditions']:
                    itr_data += ' %8.2f' % (pol_costs[m])
                else:
                    itr_data += ' %8s' % ("N/A")
        print(itr_data)

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(1) # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text(
                    ('Resuming training from algorithm state at iteration %d.\n' +
                    'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)   # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i)
                )
                # print('taking sample inside _take_sample')
                self.agent.sample(
                    pol, cond,
                    verbose=(i < self._hyperparams['verbose_trials'])
                )

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond,
                verbose=(i < self._hyperparams['verbose_trials']),
                noisy=False
            )

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        print('Taking dynamics and policy iteration %i' % itr)
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        # print('check2')
        if self.gui:
            # print('taking policy samples')
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy, self._test_idx[cond],
                verbose=verbose, save=False, noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.save_figure(
                self._data_files_dir + ('figure_itr_%02d.png' % itr)
            )
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm)
        )
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists)
        )
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists)
            )

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
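
Example #11 exports each layer's weights as pol_wgts_itr_<itr>_l_<layer>.csv with the bias appended as the final row. A small sketch of reading one such file back; the filename is illustrative and depends on the experiment's data_files_dir.

# Reload a saved layer and split it into kernel and bias again.
import numpy as np

weights = np.loadtxt('data_files/pol_wgts_itr_0_l_0.csv', delimiter=',')
kernel, bias = weights[:-1], weights[-1]   # the last row holds the bias
print(kernel.shape, bias.shape)
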
Example #12
File: gps_main.py  Project: rsdk/inpuls
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False, no_algorithm=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']
        config['agent']['data_files_dir'] = self._data_files_dir
        config['algorithm']['data_files_dir'] = self._data_files_dir

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = None
        if config['gui_on']:
            from gps.gui.gps_training_gui import GPSTrainingGUI  # Only import if necessary
            self.gui = GPSTrainingGUI(config['common'])
        self.mode = None

        config['algorithm']['agent'] = self.agent
        if not no_algorithm:
            self.algorithm = config['algorithm']['type'](config['algorithm'])
            self.algorithm._data_files_dir = self._data_files_dir
            if hasattr(self.algorithm, 'policy_opt'):
                self.algorithm.policy_opt._data_files_dir = self._data_files_dir

        self.session_id = None

    def run(self, session_id, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)
        if session_id is not None:
            self.session_id = session_id

        for itr in range(itr_start, self._hyperparams['iterations']):
            self.iteration_count = itr
            if hasattr(self.algorithm, 'policy_opt'):
                self.algorithm.policy_opt.iteration_count = itr

            print("*** Iteration %02d ***" % itr)
            # Take trajectory samples
            with Timer(self.algorithm.timers, 'sampling'):
                for cond in self._train_idx:
                    for i in trange(self._hyperparams['num_samples'],
                                    desc='Taking samples'):
                        self._take_sample(itr, cond, i)
            traj_sample_lists = [
                self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]
            self.export_samples(traj_sample_lists)

            # Iteration
            with Timer(self.algorithm.timers, 'iteration'):
                self.algorithm.iteration(traj_sample_lists, itr)
            self.export_dynamics()
            self.export_controllers()
            self.export_times()

            # Sample learned policies for visualization

            # LQR policies static resets
            if self._hyperparams['num_lqr_samples_static'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_lqr_samples_static'],
                        pol=None,
                        rnd=False), '_lqr-static')

            # LQR policies random resets
            if self._hyperparams['num_lqr_samples_random'] > 0:
                self.export_samples(
                    self._take_policy_samples(
                        N=self._hyperparams['num_lqr_samples_random'],
                        pol=None,
                        rnd=True), '_lqr-random')

            if hasattr(self.algorithm, 'policy_opt'):
                # Global policy static resets
                if self._hyperparams['num_pol_samples_static'] > 0:
                    self.export_samples(
                        self._take_policy_samples(
                            N=self._hyperparams['num_pol_samples_static'],
                            pol=self.algorithm.policy_opt.policy,
                            rnd=False), '_pol-static')

                # Global policy random resets
                if self._hyperparams['num_pol_samples_random'] > 0:
                    self.export_samples(
                        self._take_policy_samples(
                            N=self._hyperparams['num_pol_samples_random'],
                            pol=self.algorithm.policy_opt.policy,
                            rnd=True), '_pol-random')

        self._end()

    def test_policy(self, itr, N, reset_cond=None):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit()
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N, reset_cond=reset_cond)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if 'tac_policy' in self.algorithm._hyperparams and self.algorithm.iteration_count > 0:
            pol = PolicyTAC(
                self.algorithm,
                self.algorithm._hyperparams['tac_policy']['history'])
            if 'tac' in self.algorithm._hyperparams:
                self.agent.T = self.algorithm._hyperparams['tac'][
                    'T']  # Use possibly larger T for on-policy sampling
        elif self.algorithm._hyperparams[
                'sample_on_policy'] and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
            if 'tac' in self.algorithm._hyperparams:
                self.agent.T = self.algorithm._hyperparams['tac'][
                    'T']  # Use possibly larger T for on-policy sampling
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'test':
                        self.mode = 'test'
                        self._take_policy_samples()
                        self.gui.request = 'stop'
                    if self.gui.request == 'go':
                        self.mode = 'go'
                    if self.gui.request == 'gcm':
                        self.mode = 'gcm'
                    if self.gui.request == 'stop':
                        self.mode = 'stop'

                    if self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']),
                    noisy=True,
                    use_TfController=True,
                    rnd=self.agent._hyperparams['random_reset'])

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol,
                cond,
                verbose=(i < self._hyperparams['verbose_trials']),
                noisy=True,
                use_TfController=True,
                reset_cond=None
                if self.agent._hyperparams['random_reset'] else cond)

    def _take_policy_samples(self, N, pol, rnd=False):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
            pol: Policy to sample. None for LQR policies.
        Returns: None
        """
        if pol is None:
            pol_samples = [[None] * N for _ in self._test_idx]  # independent inner lists
            for i, cond in enumerate(self._test_idx, 0):
                for n in trange(
                        N,
                        desc='Taking LQR-policy samples m=%d, cond=%s' %
                    (cond, 'rnd' if rnd else cond)):
                    pol_samples[i][n] = self.agent.sample(
                        self.algorithm.cur[cond].traj_distr,
                        None,
                        verbose=None,
                        save=False,
                        noisy=False,
                        reset_cond=None if rnd else cond,
                        record=False)
            return [SampleList(samples) for samples in pol_samples]
        else:
            conds = self._test_idx if not rnd else [None]
            # stores where the policy has led to
            pol_samples = [[None] * N for _ in conds]  # independent inner lists
            for i, cond in enumerate(conds):
                for n in trange(N,
                                desc='Taking %s policy samples cond=%s' %
                                (type(pol).__name__, 'rnd' if rnd else cond)):
                    pol_samples[i][n] = self.agent.sample(pol,
                                                          None,
                                                          verbose=None,
                                                          save=False,
                                                          noisy=False,
                                                          reset_cond=cond,
                                                          record=n < 0)
            return [SampleList(samples) for samples in pol_samples]

    def _log_data(self,
                  itr,
                  traj_sample_lists,
                  pol_sample_lists=None,
                  controller=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        #print("log 0")
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            #self.gui.update(itr, self.algorithm, self.agent,
            #                copy.copy(traj_sample_lists), pol_sample_lists)
            #self.gui.save_figure(
            #    self._data_files_dir + ('figure_itr_%02d.png' % itr)
            #    )
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        #print("log 1")
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        #print("log 2")
        self.data_logger.pickle(
            self._data_files_dir + ('%s_samplelist_itr%02d.pkl' %
                                    (self.session_id, itr)),
            copy.copy(traj_sample_lists))
        if controller:
            self.data_logger.pickle(
                self._data_files_dir + ('%s_controller_itr%02d.pkl' %
                                        (self.session_id, itr)),
                copy.copy(controller))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_lqr_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)

    def export_samples(self, traj_sample_lists, sample_type=''):
        """
        Exports trajectory samples in a compressed numpy file.
        """
        M, N, T, dX, dU = len(traj_sample_lists), len(
            traj_sample_lists[0]), self.agent.T, self.agent.dX, self.agent.dU
        X = np.empty((M, N, T, dX))
        U = np.empty((M, N, T, dU))

        for m in range(M):
            sample_list = traj_sample_lists[m]
            for n in range(N):
                sample = sample_list[n]
                X[m, n] = sample.get_X()
                U[m, n] = sample.get_U()

        np.savez_compressed(
            self._data_files_dir + 'samples%s_%02d' %
            (sample_type, self.iteration_count),
            X=X,
            U=U,
        )

    def export_dynamics(self):
        """
        Exports the local dynamics data in a compressed numpy file.
        """
        M, T, dX, dU = self.algorithm.M, self.agent.T, self.agent.dX, self.agent.dU
        Fm = np.empty((M, T - 1, dX, dX + dU))
        fv = np.empty((M, T - 1, dX))
        dyn_covar = np.empty((M, T - 1, dX, dX))

        for m in range(M):
            dynamics = self.algorithm.cur[m].traj_info.dynamics
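            # Time-varying linear-Gaussian dynamics fitted by the algorithm:
            # x_{t+1} ~ N(Fm[t] [x_t; u_t] + fv[t], dyn_covar[t]).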
            Fm[m] = dynamics.Fm[:-1]
            fv[m] = dynamics.fv[:-1]
            dyn_covar[m] = dynamics.dyn_covar[:-1]

        np.savez_compressed(
            self._data_files_dir + 'dyn_%02d' % self.iteration_count,
            Fm=Fm,
            fv=fv,
            dyn_covar=dyn_covar,
        )

    def export_controllers(self):
        """
        Exports the local controller data in a compressed numpy file.
        """
        M, T, dX, dU = self.algorithm.M, self.agent.T, self.agent.dX, self.agent.dU
        K = np.empty((M, T - 1, dU, dX))
        k = np.empty((M, T - 1, dU))
        prc = np.empty((M, T - 1, dU, dU))

        traj_mu = np.empty((M, T, dX + dU))
        traj_sigma = np.empty((M, T, dX + dU, dX + dU))

        for m in range(M):
            traj = self.algorithm.cur[m].traj_distr
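            # Time-varying linear-Gaussian controller u_t = K[t] x_t + k[t]; prc holds the
            # inverse policy covariance (precision). traj_mu / traj_sigma are the
            # state-action marginals from the algorithm's forward pass.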
            K[m] = traj.K[:-1]
            k[m] = traj.k[:-1]
            prc[m] = traj.inv_pol_covar[:-1]
            traj_mu[m] = self.algorithm.new_mu[m]
            traj_sigma[m] = self.algorithm.new_sigma[m]

        np.savez_compressed(
            self._data_files_dir + 'ctr_%02d' % self.iteration_count,
            K=K,
            k=k,
            prc=prc,
            traj_mu=traj_mu,
            traj_sigma=traj_sigma,
        )

    def export_times(self):
        """
        Exports timer values into a csv file by appending a line for each iteration.
        """
        header = ','.join(
            self.algorithm.timers.keys()) if self.iteration_count == 0 else ''
        with open(self._data_files_dir + 'timers.csv', 'ab') as out_file:
            np.savetxt(
                out_file,
                np.asarray([list(self.algorithm.timers.values())]),
                header=header)
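
A minimal sketch (not part of the example above; the directory and iteration number are
assumptions) of how the arrays written by export_samples and export_dynamics could be
loaded back for offline analysis:

import numpy as np

data = np.load('data_files/samples_00.npz')   # written by export_samples(..., sample_type='')
X, U = data['X'], data['U']                   # shapes (M, N, T, dX) and (M, N, T, dU)
dyn = np.load('data_files/dyn_00.npz')        # written by export_dynamics()
Fm, fv = dyn['Fm'], dyn['fv']                 # x_{t+1} is approximately Fm[t] [x; u] + fv[t]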
Example #13
File: gps_main.py Project: w547341387/gps
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        self.use_mpc = False
        if 'use_mpc' in config['common'] and config['common']['use_mpc']:
            self.use_mpc = True
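            # The MPC agent uses the shorter horizon M, while self.agent (built above with
            # the full horizon T) keeps collecting the full-length offline samples.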
            config['agent']['T'] = config['agent']['M']
            self.mpc_agent = config['agent']['type'](config['agent'])

            # Algorithm.__init__ deleted the agent entry (it is not pickled), so restore it.
            config['algorithm']['agent'] = self.agent

            self.algorithm.init_mpc(config['num_samples'], config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            itr_start = self._initialize(itr_load)

            for itr in range(itr_start, self._hyperparams['iterations']):
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # Clear agent samples.
                self.agent.clear_samples()

                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)
        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))

                self._roll_out(pol, itr, cond, i)

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self._roll_out(pol, itr, cond, i)

    def _roll_out(self, pol, itr, cond, i):
        if self.use_mpc and itr > 0:
            T = self.agent.T
            M = self.mpc_agent.T
            N = int(ceil(T / (M - 1.)))
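            # The full horizon T is covered by N MPC segments, each contributing M - 1
            # steps (the final MPC action is discarded), hence N = ceil(T / (M - 1)).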
            X_t = self.agent.x0[cond]

            # Run the forward pass only once per condition,
            # because it is the same for all samples.
            if i == 0:
                # Note: At this time algorithm.prev = algorithm.cur,
                #       and prev.traj_info already have x0mu, x0sigma.
                self.off_prior, _ = self.algorithm.traj_opt.forward(
                    pol, self.algorithm.prev[cond].traj_info)
                self.agent.publish_plan(self.off_prior)

            if type(self.algorithm) == AlgorithmTrajOpt:
                pol_info = None
            else:
                pol_info = self.algorithm.cur[cond].pol_info

            for n in range(N):
                # Note: M-1 because action[M] = [0,0].
                t_traj = n * (M - 1)
                reset = (n == 0)

                mpc_pol, mpc_state = self.algorithm.mpc[cond][i].update(
                    n, X_t, self.off_prior, pol,
                    self.algorithm.cur[cond].traj_info, t_traj, pol_info)
                self.agent.publish_plan(mpc_state, True)
                new_sample = self.mpc_agent.sample(
                    mpc_pol,
                    cond,
                    reset=reset,
                    noisy=True,
                    verbose=(i < self._hyperparams['verbose_trials']))
                X_t = new_sample.get_X(t=M - 1)
            """
             Merge sample for optimize offline trajectory distribution
            """
            full_sample = Sample(self.agent)
            sample_lists = self.mpc_agent.get_samples(cond)
            keys = sample_lists[0]._data.keys()
            t = 0
            for sample in sample_lists:
                for m in range(sample.T - 1):
                    for sensor in keys:
                        full_sample.set(sensor, sample.get(sensor, m), t)
                    t = t + 1
                    if t + 1 > T:
                        break

            self.agent._samples[cond].append(full_sample)
            # Clear agent samples.
            self.mpc_agent.clear_samples()
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: A list of SampleList objects, or None if policy trials are disabled.
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
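
A minimal sketch (the experiment path is an assumption) of how these GPSMain variants are
typically driven from an experiment's hyperparams module, following the convention of the
original gps_main.py:

import imp  # the GPS codebase targets Python 2; importlib can be used on Python 3

hyperparams = imp.load_source('hyperparams',
                              'experiments/my_experiment/hyperparams.py')
gps = GPSMain(hyperparams.config, quit_on_end=False)
gps.run(itr_load=None)   # or gps.test_policy(itr=9, N=5) to evaluate a saved iteration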
Example #14
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, config, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            config: Experiment hyperparameters (unused in this method;
                self._hyperparams is used instead).
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)
        position_train = self.data_logger.unpickle(
            './position/position_train.pkl')
        T = self.algorithm.T
        N = self._hyperparams['num_samples']
        dU = self.algorithm.dU
        for num_pos in range(position_train.shape[0]):
            """ load train position and reset agent model. """
            for cond in self._train_idx:
                self._hyperparams['agent']['pos_body_offset'][
                    cond] = position_train[num_pos]
            self.agent.reset_model(self._hyperparams)

            # Initialize the training-data arrays
            train_prc = np.zeros((0, T, dU, dU))
            train_mu = np.zeros((0, T, dU))
            train_obs_data = np.zeros((0, T, self.algorithm.dO))
            train_wt = np.zeros((0, T))

            # Initialize the success counter
            count_suc = 0

            for itr in range(itr_start, self._hyperparams['iterations']):
                print('******************num_pos:************', num_pos)
                print('______________________itr:____________', itr)
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # Calculate the distance from the end-effector to the target position
                ee_pos = self.agent.get_ee_pos(cond)[:3]
                target_pos = self.agent._hyperparams['target_ee_pos'][:3]
                distance_pos = ee_pos - target_pos
                distance_ee = np.sqrt(distance_pos.dot(distance_pos))
                print('distance ee:', distance_ee)

                # Collect the successful samples used to train the global policy
                if distance_ee <= 0.06:
                    count_suc += 1
                    tgt_mu, tgt_prc, obs_data, tgt_wt = self.train_prepare(
                        traj_sample_lists)
                    train_mu = np.concatenate((train_mu, tgt_mu))
                    train_prc = np.concatenate((train_prc, tgt_prc))
                    train_obs_data = np.concatenate((train_obs_data, obs_data))
                    train_wt = np.concatenate((train_wt, tgt_wt))

                # Clear agent samples.
                self.agent.clear_samples()

                # Stop once enough successful samples have been collected
                if count_suc > 8:
                    break

                self._take_iteration(itr, traj_sample_lists)
                # pol_sample_lists = self._take_policy_samples()
                # self._log_data(itr, traj_sample_lists, pol_sample_lists)

            # Get samples from the previous policy for the collected observations
            if num_pos > 0:
                previous_mu = self.algorithm.get_previous_sample(
                    train_obs_data)
            # train NN with good samples
            self.algorithm.policy_opt.update(train_obs_data, train_mu,
                                             train_prc, train_wt)
            # Test the trained policy at the current position
            self.test_current_policy()
            # Reset the algorithm to its initial state for the next position
            self.algorithm.reset_alg()

        self._end()

    def train_prepare(self, sample_lists):
        """
        Prepare the supervised training data from the sample lists.
        Args:
            sample_lists: sample list from agent

        Returns:
            target mu, prc, obs_data, wt

        """
        algorithm = self.algorithm
        dU, dO, T = algorithm.dU, algorithm.dO, algorithm.T
        obs_data, tgt_mu = np.zeros((0, T, dO)), np.zeros((0, T, dU))
        tgt_prc = np.zeros((0, T, dU, dU))
        tgt_wt = np.zeros((0, T))
        wt_origin = 0.01 * np.ones(T)
        for m in range(algorithm.M):
            samples = sample_lists[m]
            X = samples.get_X()
            N = len(samples)
            prc = np.zeros((N, T, dU, dU))
            mu = np.zeros((N, T, dU))
            wt = np.zeros((N, T))

            traj = algorithm.cur[m].traj_distr
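            # The supervised target at each timestep is the linear-Gaussian controller's
            # mean action mu = K[t] x + k[t], weighted by the precision prc (the inverse
            # policy covariance) and a constant weight wt.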
            for t in range(T):
                prc[:, t, :, :] = np.tile(traj.inv_pol_covar[t, :, :],
                                          [N, 1, 1])
                for i in range(N):
                    mu[i,
                       t, :] = (traj.K[t, :, :].dot(X[i, t, :]) + traj.k[t, :])
                wt[:, t].fill(wt_origin[t])
            tgt_mu = np.concatenate((tgt_mu, mu))
            tgt_prc = np.concatenate((tgt_prc, prc))
            obs_data = np.concatenate((obs_data, samples.get_obs()))
            tgt_wt = np.concatenate((tgt_wt, wt))

        return tgt_mu, tgt_prc, obs_data, tgt_wt

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: A list of SampleList objects, or None if policy trials are disabled.
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def test_current_policy(self):
        """
        Test the current NN policy at the current position.
        Returns: None
        """
        verbose = self._hyperparams['verbose_policy_trials']
        for cond in self._train_idx:
            samples = self.agent.sample(self.algorithm.policy_opt.policy,
                                        cond,
                                        verbose=verbose,
                                        save=False,
                                        noisy=False)

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
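
A minimal sketch (assumptions: each row of the pickled array is a 3-D pos_body_offset, and
DataLogger.pickle mirrors the unpickle call used in run() above) of how
./position/position_train.pkl could be generated for this example:

import numpy as np
from gps.utility.data_logger import DataLogger

positions = np.random.uniform(-0.1, 0.1, size=(20, 3))  # 20 candidate body offsets
positions[:, 2] = 0.0                                    # keep the offsets in the plane
DataLogger().pickle('./position/position_train.pkl', positions)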
Example #15
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        # This controls whether there is a learned reset controller.
        # There is a learned reset when special_reset = True.
        self.special_reset = config['agent'].get('special_reset', False)
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        # There are as many reset conditions as regular conditions
        self._reset_conditions = self._conditions
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        import tensorflow as tf
        try:
            with tf.variable_scope(
                    'reset'):  # to avoid variable naming conflicts
                # Build the algorithm for the learned reset controller as well
                self.reset_algorithm = config['reset_algorithm']['type'](
                    config['algorithm'])
        except Exception:
            # No reset_algorithm in the config (or it failed to build); run without a
            # learned reset.
            pass

        self.saved_algorithm = copy.deepcopy(
            self.algorithm)  # Save this newly initialized alg
        # If you want to warm start the algorithm BADMM with learned iLQG controllers
        self.diff_warm_start = True
        # If you want to warm start the neural network with some pretrained network
        self.nn_warm_start = False
        # The following variables determine the pretrained network path
        # Change these if you want to take the other policy type
        attention, structure = 'time', 'mlp'
        self.policy_path = os.path.join(
            self._data_files_dir,
            os.path.join('{}_{}'.format(attention, structure), 'policy'))
        try:
            self.old_policy_opt = copy.deepcopy(self.algorithm.policy_opt)
        except Exception:
            # policy_opt may not exist or may not be deep-copyable; skip the backup copy.
            pass

        pdb.set_trace()

    # Specially initialize the algorithm after with pretrained things
    def special_init_alg(self):

        resumed_alg = self.algorithm  # We picked this up by resuming
        # SPECIFIC TO BADMM AND MDGPS
        if (type(self.saved_algorithm) is AlgorithmBADMM and not(type(resumed_alg) is AlgorithmBADMM)) or \
        (type(self.saved_algorithm) is AlgorithmMDGPS and not(type(resumed_alg) is AlgorithmMDGPS)):
            self.algorithm = self.saved_algorithm  # Return it to the new type we want to use
            # Keep a copy of the new algorithm's hyperparams
            theParams = copy.deepcopy(self.algorithm._hyperparams)
            # For all the instance variables in the resumed algorithm
            for item in resumed_alg.__dict__:
                # Copy each attribute from the resumed algorithm onto the new one
                setattr(self.algorithm, item, resumed_alg.__dict__[item])

            # Except for the hyperparams, which must stay those of the new algorithm type
            self.algorithm._hyperparams = theParams
            self.algorithm.re_init_pol_info(theParams)  # Reinitialize the policy info
            # Get rid of the prev data, which would otherwise break the linear-algebra updates
            self.algorithm.prev = [
                IterationData() for _ in range(self.algorithm.M)
            ]
            self.algorithm.iteration_count = 0  # Pretend this is the first iteration

        pdb.set_trace()

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            itr_start = self._initialize(itr_load)

            for itr in range(itr_start, self._hyperparams['iterations']):
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)
                        if self.special_reset:  # If there is a special reset
                            # Take a special sample
                            self._take_sample(itr, cond, i, reset=True)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]
                if self.special_reset:  # Again, if there is a special reset
                    reset_traj_sample_lists = [
                        self.agent.get_reset_samples(
                            cond, -self._hyperparams['num_samples'])
                        for cond in self._train_idx
                    ]
                    self._take_iteration(itr,
                                         reset_traj_sample_lists,
                                         reset=True)

                # Clear agent samples. (Including the reset ones)
                self.agent.clear_samples()
                #pdb.set_trace()
                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)
        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            print("TRAINING STARTING OFF FROM ITR_LOAD" + str(itr_load))
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread
            if self.diff_warm_start:  # If we are warm starting the algorithm
                # Graft the resumed algorithm's state onto the new algorithm type.
                self.special_init_alg()
            if self.nn_warm_start:  # If we are warm starting the neural network
                # Restore the policy opt with the policy in the given policy path
                self.algorithm.policy_opt.restore_model(self.policy_path)

            self.agent.itr_load = itr_load  # Set the iter load
            pdb.set_trace()

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    # The reset is if this sample is a reset sample
    def _take_sample(self, itr, cond, i, reset=False):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            # Use the reset algorithm policy if this is a reset sample
            if reset:
                pol = self.reset_algorithm.policy_opt.policy
            else:  # Otherwise use the primary algorithm's policy
                pol = self.algorithm.policy_opt.policy
        else:
            if reset:
                pol = self.reset_algorithm.cur[cond].traj_distr
            else:
                pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.
                if reset:  # If we are doing a reset one
                    self.gui.set_status_text(
                        'Sampling reset: iteration %d, condition %d, sample %d.'
                        % (itr, cond, i))
                else:  # Otherwise this is a normal (non-reset) sample
                    self.gui.set_status_text(
                        'Sampling: iteration %d, condition %d, sample %d.' %
                        (itr, cond, i))
                if reset:
                    # Flag the agent so the next rollout is treated as a reset sample.
                    self.agent.reset_time = True
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            verbose = i < self._hyperparams['verbose_trials']
            self.agent.sample(pol, cond, verbose=verbose, reset=reset)

    def _take_iteration(self, itr, sample_lists, reset=False):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.agent.reset(0)  # so the arm doesn't roll

        if reset:  # If we are resetting, iterate for reset algorithm
            #pass # Uncomment this and comment below line if you don't want to learn reset
            self.reset_algorithm.iteration(sample_lists)
        else:  # Otherwise, iterate for normal algorithm
            self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: A list of SampleList objects, or None if policy trials are disabled.
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=True)  #otherPol=self.algorithm.cur[cond].traj_distr)
        #pdb.set_trace()
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        # Maybe pickle the agent to help out?
        self.data_logger.pickle(
            self._data_files_dir + ('agent_itr_%02d.pkl' % itr),
            copy.copy(self.agent))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
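
A minimal sketch (experiment path and iteration number are assumptions) of resuming
Example #15 with its two warm-start switches: diff_warm_start grafts the resumed
algorithm's state onto the freshly built algorithm via special_init_alg(), and
nn_warm_start restores a pretrained network from self.policy_path:

import imp

hyperparams = imp.load_source('hyperparams',
                              'experiments/my_experiment/hyperparams.py')
gps = GPSMain(hyperparams.config)   # note: __init__ above stops at pdb.set_trace()
gps.nn_warm_start = False           # set True to restore the network at self.policy_path
gps.run(itr_load=9)                 # _initialize() unpickles algorithm_itr_09.pkl and,
                                    # because diff_warm_start is True, calls special_init_alg()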
Example #16
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, time_experiment, exper_condition, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            time_experiment, exper_condition: identify the experiment run (only referenced
                by the commented-out test-position loading below).
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)

        # test_position = self.data_logger.unpickle('./position/%d/%d/test_position.pkl'
        #                                           % (time_experiment, exper_condition))
        self.target_ee_point = self.agent._hyperparams['target_ee_points'][:3]

        for itr in range(itr_start, self._hyperparams['iterations']):
            print('itr******:  %d   **********' % itr)
            for cond in self._train_idx:
                for i in range(self._hyperparams['num_samples']):
                    self._take_sample(itr, cond, i)

            traj_sample_lists = [
                self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]

            # Clear agent samples.
            self.agent.clear_samples()

            self._take_iteration(itr, traj_sample_lists)
            pol_sample_lists = self._take_policy_samples()

            self._log_data(itr, traj_sample_lists, pol_sample_lists)
        """ test policy and collect costs"""
        """
        gradually add the distance of agent position
        """
        center_position = 0.02
        radius = 0.02
        max_error_bound = 0.02
        directory = 9
        for test_condition in range(7):
            # test_position = self.generate_position(center_position, radius, 30, max_error_bound)
            test_position = self.data_logger.unpickle(
                './position/test_position_%d.pkl' % (test_condition + 1))
            costs, position_suc_count, distance = self.test_cost(
                test_position, len(pol_sample_lists))
            print('distance:', distance)
            # add the position_suc_count
            if test_condition == 0:
                # Add a leading axis so later results can be vstacked
                all_pos_suc_count = np.expand_dims(position_suc_count, axis=0)
                all_distance = np.expand_dims(distance, axis=0)
            else:
                all_pos_suc_count = np.vstack(
                    (all_pos_suc_count, position_suc_count))
                all_distance = np.vstack((all_distance, distance))

            costs = costs.reshape(costs.shape[0] * costs.shape[1])
            mean_cost = np.array([np.mean(costs)])
            center_position = center_position + radius * 2

        self._end()
        return costs, mean_cost, all_pos_suc_count, all_distance

    def generate_position(self, cposition, radius, conditions,
                          max_error_bound):
        # all_positions = np.zeros(0)

        while True:
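            # Rejection sampling: draw `conditions` points inside a circle of the given
            # radius around the center, and regenerate the whole set until its mean lies
            # close enough (within max_error_bound) to the nominal center.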
            all_positions = np.array([cposition, -cposition, 0])
            center_position = np.array([cposition, -cposition, 0])
            for i in range(conditions):
                position = np.random.uniform(cposition - radius,
                                             cposition + radius, 3)
                while True:
                    position[2] = 0
                    position[1] = -position[1]
                    area = (position - center_position).dot(position -
                                                            center_position)
                    # area = np.sum(np.multiply(position - center_position, position - center_position))
                    if area <= radius**2:
                        # print(area)
                        break
                    position = np.random.uniform(cposition - radius,
                                                 cposition + radius, 3)
                position = np.floor(position * 1000) / 1000.0
                all_positions = np.concatenate((all_positions, position))
            all_positions = np.reshape(all_positions,
                                       [all_positions.shape[0] // 3, 3])
            # print(all_positions[:, 1])
            # print('mean:')
            # print(np.mean(all_positions, axis=0))
            mean_position = np.mean(all_positions, axis=0)
            # mean_error1 = np.fabs(mean_position[0] - 0.11)
            # mean_error2 = np.fabs(mean_position[1] + 0.11)
            mean_error1 = np.fabs(mean_position[0] -
                                  (cposition - max_error_bound))
            mean_error2 = np.fabs(mean_position[1] +
                                  (cposition - max_error_bound))
            if mean_error1 < max_error_bound and mean_error2 < max_error_bound:
                print('mean:')
                print(np.mean(all_positions, axis=0))
                break
        print(all_positions)
        print(all_positions.shape)
        return all_positions

    def test_cost(self, positions, train_cond):
        """
        test policy and collect costs
        Args:
            positions: test position from test_position.pkl

        Returns:
            cost:   mean cost of all test position
            total_suc:  successful pegging trial count  1:successful    0:fail

        """
        iteration = positions.shape[0] // train_cond  # integer number of test iterations
        total_costs = list()
        total_ee_points = list()
        total_suc = np.zeros(0)
        total_distance = np.zeros(0)
        for itr in range(iteration):
            for cond in self._train_idx:
                self._hyperparams['agent']['pos_body_offset'][
                    cond] = positions[itr + cond]
            self.agent.reset_model(self._hyperparams)
            _, cost, ee_points = self._test_policy_samples()
            for cond in self._train_idx:
                total_ee_points.append(ee_points[cond])
            total_costs.append(cost)
        print("total_costs:", total_costs)
        for i in range(len(total_ee_points)):
            ee_error = total_ee_points[i][:3] - self.target_ee_point
            distance = ee_error.dot(ee_error)**0.5
            if (distance < 0.06):
                total_suc = np.concatenate((total_suc, np.array([1])))
            else:
                total_suc = np.concatenate((total_suc, np.array([0])))
            total_distance = np.concatenate(
                (total_distance, np.array([distance])))
        return np.array(total_costs), total_suc, total_distance
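        # Added note: a trial counts as successful when the final end-effector point
        # is within 0.06 m of the target. A per-condition success rate can then be
        # obtained from the stacked results, e.g. all_pos_suc_count.mean(axis=1) in
        # the test loop above (a suggested usage, not in the original code).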

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def _test_policy_samples(self, N=None):
        """
        test sample from the policy and collect the costs
        Args:
            N:

        Returns:
            samples
            costs:      list of cost for each condition
            ee_point:   list of ee_point for each condition

        """
        if 'verbose_policy_trials' not in self._hyperparams:
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        costs = list()
        ee_points = list()
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
            # in algorithm.py: _eval_cost
            policy_cost = self.algorithm.cost[0].eval(pol_samples[cond][0])[0]
            policy_cost = np.sum(policy_cost)  # sum the per-timestep cost over the 100-step trajectory
            costs.append(policy_cost)
            ee_points.append(self.agent.get_ee_point(cond))
        return [SampleList(samples)
                for samples in pol_samples], costs, ee_points

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
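The generate_position method in the example above is, at its core, rejection sampling: draw a candidate offset uniformly from a square, accept it only if it falls inside the circle of the given radius, and repeat the whole batch until the sample mean is close enough to the intended center. A compact, self-contained sketch of that idea (a hypothetical helper, not code from the example) looks like this:

import numpy as np

def sample_positions_in_circle(center_xy, radius, n, seed=None):
    """Rejection-sample n planar positions uniformly inside a circle."""
    rng = np.random.RandomState(seed)
    positions = []
    while len(positions) < n:
        offset = rng.uniform(-radius, radius, size=2)        # candidate in the bounding square
        if offset.dot(offset) <= radius ** 2:                 # keep it only if inside the circle
            positions.append([center_xy[0] + offset[0],
                              center_xy[1] + offset[1], 0.0])
    # truncate to millimetres, matching the np.floor(position * 1000) / 1000 step above
    return np.floor(np.array(positions) * 1000.0) / 1000.0

print(sample_positions_in_circle((0.02, -0.02), 0.02, 5, seed=0))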
Example #17
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """

        # VARIABLE INITIALISATION #

        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            itr_start = self._initialize(
                itr_load
            )  # perform initialization, and return starting iteration

            for itr in range(
                    itr_start,
                    self._hyperparams['iterations']):  # iterations of GPS
                for cond in self._train_idx:  # number of distinct problem instantiations for training
                    for i in range(self._hyperparams['num_samples']
                                   ):  # number of gps samples
                        self._take_sample(
                            itr, cond, i
                        )  # take sample from LQR controller, and store in agent object

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]  # store list of trajectory samples, for all training conditions on this iteration

                # Clear agent samples.
                self.agent.clear_samples()

                self._take_iteration(
                    itr, traj_sample_lists
                )  # take iteration, updating dynamics model, and minimising cost of LQR controller
                pol_sample_lists = self._take_policy_samples(
                )  # take samples of current LQR policy, to see how it's doing
                self._log_data(
                    itr, traj_sample_lists, pol_sample_lists
                )  # logs samples from the policy before and after the iteration taken, and the algorithm object, into pickled bit-stream files
        except Exception as e:
            traceback.print_exception(*sys.exc_info())  # print the traceback for any exception raised during training
        finally:  # runs whether or not an exception occurred
            self._end(
            )  # show in the gui that the process ended, and exit python if quit_on_end was set

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists)
        )  # save the policy samples as a pickled bit-stream file

        if self.gui:  # if using gui
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:  # if no iteration to load from specified
            if self.gui:  # if using gui
                self.gui.set_status_text('Press \'go\' to begin.'
                                         )  # inform user to start the process
            return 0  # return iteration number as 0, ie start from beginning
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load  # define algorithm file based on desired iteration
            self.algorithm = self.data_logger.unpickle(
                algorithm_file
            )  # Read string from file and interpret as pickle data stream, reconstructing and returning the original object
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:  # if using gui
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load)
                )  # unpickle traj sample lists for current iterations
                if self.algorithm.cur[0].pol_info:  # if policy info available
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load)
                    )  # unpickle policy sample lists for current iteration
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load
                                         )  # inform user to start the process
            return itr_load + 1  # return iteration number to begin working from

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy  # There is NO policy optimisation for simple traj opt algorithm
        else:
            pol = self.algorithm.cur[
                cond].traj_distr  # initialise LQR controller based on current trajectories
            # Note: policy.act(array(7)->X, array(empty)->obs, int->t, array(2)->noise) to retrieve policy actions conditioned on state
        if self.gui:  # if using gui
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                # For Handling GUI Requests 'stop', 'reset', 'go', 'fail' #
                while self.gui.mode in (
                        'wait', 'request', 'process'
                ):  # while gui is waiting, requesting, or processing
                    if self.gui.mode in (
                            'wait', 'process'):  # if waiting or processing,
                        time.sleep(0.01)
                        continue  # continue again at start of while loop
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))  # display to user
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials'])
                )  # run a trial using the policy, and save the trajectory into the agent object (AgentBox2D or AgentMuJoCo class)

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()  # Complete request.
                    self.agent.delete_last_sample(
                        cond
                    )  # delete last sample, and redo, on account of fail request
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials'])
            )  # run a trial using the policy, and save the trajectory into the agent object (AgentBox2D or AgentMuJoCo class)

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(
            sample_lists
        )  # take iteration, training LQR controller and updating dynamics model
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams[
            'verbose_policy_trials']  # bool controlling whether trials are run verbosely
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False
            )  # iterate through problem instantiations, accumulating policy samples from LQR policy
        return [SampleList(samples) for samples in pol_samples
                ]  # return samples, held in SampleList objects

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:  # if using gui
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm)
        )  # store current algorithm iteration object as bit-stream file
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists)
        )  # store current trajectory sample list as bit-stream file
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists)
            )  # store current policy sample list as bit-stream file

    def _end(self):
        """ Finish running and exit. """
        if self.gui:  # if using gui
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:  # user argument for main
                # Quit automatically (for running sequential expts)
                os._exit(1)  # exit python altogether
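Example #17 above documents the standard training loop in detail but, like the others, assumes an external launcher builds the config and calls run(). A minimal, hypothetical launcher sketch (assuming the experiment's hyperparams.py exposes a module-level config dict, as in the GPS framework) could look like this:

import imp  # the examples target Python 2-style imports; importlib works similarly on Python 3

def launch(hyperparams_path, resume_itr=None, quit_on_end=False):
    # Load the experiment's hyperparameter module and hand its config to GPSMain.
    hyperparams = imp.load_source('hyperparams', hyperparams_path)
    gps = GPSMain(hyperparams.config, quit_on_end)
    # When resume_itr is given, run() reloads algorithm_itr_%02d.pkl and continues from there.
    gps.run(itr_load=resume_itr)

launch('experiments/my_experiment/hyperparams.py')  # hypothetical experiment path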
Example #18
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
		Initialize GPSMain
		Args:
			config: Hyperparameters for experiment
			quit_on_end: When true, quit automatically on completion
		"""
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        # print(config)
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']
        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        # hard code to pass the map_state and target_state
        config['algorithm']['cost']['costs'][1]['data_types'][3][
            'target_state'] = config['agent']['target_state']
        config['algorithm']['cost']['costs'][1]['data_types'][3][
            'map_size'] = config['agent']['map_size']
        # config['algorithm']['cost']['costs'][1]['data_types'][3]['map_size'] = CUT_MAP_SIZE

        if len(config['algorithm']['cost']['costs']) > 2:
            # temporarily deprecated, not considering collision cost
            # including cost_collision
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'target_state'] = config['agent']['target_state']
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'map_size'] = config['agent']['map_size']
            config['algorithm']['cost']['costs'][2]['data_types'][3][
                'map_state'] = config['agent']['map_state']
        # print(config['algorithm'])
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        # Modified by RH
        self.finishing_time = None
        self.U = None
        self.final_pos = None
        self.samples = []
        self.quick_sample = None
        # self.map_size = config['agent']['map_size']
        self.map_size = CUT_MAP_SIZE
        self.display_center = config['agent']['display_center']

    def run(self, itr_load=None):
        """
		Run training by iteratively sampling and taking an iteration.
		Args:
			itr_load: If specified, loads algorithm state from that
				iteration, and resumes training at the next iteration.
		Returns: None
		"""
        try:
            itr_start = self._initialize(itr_load)

            for itr in range(itr_start, self._hyperparams['iterations']):
                print("iter_num", itr)
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]
                for cond in self._train_idx:
                    self.samples.append(traj_sample_lists[cond].get_samples())
                # SampleList.get_X() returns one sample if given an index, otherwise all of them
                # Clear agent samples.
                self.agent.clear_samples()
                # The inner optimization loop is done in _take_iteration (self.algorithm.iteration())
                # taking an iteration corresponds to one training update
                self._take_iteration(itr, traj_sample_lists)
                # pol_sample_list is valid only for testing
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)

                if self.finishing_time:
                    break
                    # TODO: try to save and compare to find the minimal speed later after all iterations

            if not self.finishing_time:
                print("sorry, not hit")
                # TODO: find a relatively proper (nearest) sample, and return the final_pos by get_X(index)
                # assume one condition first
                # if collect samples then can only get last iteration
                self.samples = np.concatenate(self.samples)

        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
		Take N policy samples of the algorithm state at iteration itr,
		for testing the policy to see how it is behaving.
		(Called directly from the command line --policy flag).
		Args:
			itr: the iteration from which to take policy samples
			N: the number of policy samples to take
		Returns: None
		"""

        # modified by RH
        # originally, algo_itr, traj_itr, policy_itr are all the same
        # algo_itr = itr
        # traj_itr = itr
        # policy_itr = itr

        algo_itr = 18
        traj_itr = 1
        policy_itr = 18

        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))
        print("algorithm")
        print(self.algorithm)
        print("traj_sample_lists")
        print(traj_sample_lists)
        print("N", N)
        pol_sample_lists = self._take_policy_samples(N)
        print("pol_sample_lists")
        print(pol_sample_lists)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
		Initialize from the specified iteration.
		Args:
			itr_load: If specified, loads algorithm state from that
				iteration, and resumes training at the next iteration.
		Returns:
			itr_start: Iteration to start from.
		"""
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
		Collect a sample from the agent.
		Args:
			itr: Iteration number.
			cond: Condition number.
			i: Sample number.
		Returns: None
		"""
        if self.algorithm._hyperparams['sample_on_policy'] \
          and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                # self.agent.sample(
                #     pol, cond,
                #     verbose=(i < self._hyperparams['verbose_trials'])
                # )
                new_sample = self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            # self.agent.sample(
            #     pol, cond,
            #     verbose=(i < self._hyperparams['verbose_trials'])
            # )
            new_sample = self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))
        if type(self.agent) == AgentBus and self.agent.finishing_time:
            # save the sample
            if (self.finishing_time is None) or (self.agent.finishing_time <
                                                 self.finishing_time):
                self.finishing_time = self.agent.finishing_time
                print("agent_bus t= ", self.finishing_time)
                # the Sample class's get_X/get_U methods return the full state/action trajectories over all timesteps
                self.U = new_sample.get_U()[:self.finishing_time]
                self.X = new_sample.get_X()[:self.finishing_time]
                self.final_pos = new_sample.get_X()[self.finishing_time]
                # TODO: pass map_size if necessary
                # print("return final_pos", [self.final_pos[0]+self.display_center[0], self.display_center[1]-self.final_pos[1], self.final_pos[2]])

    def _take_iteration(self, itr, sample_lists):
        """
		Take an iteration of the algorithm.
		Args:
			itr: Iteration number.
		Returns: None
		"""
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
		Take samples from the policy to see how it's doing.
		Args:
			N  : number of policy samples to take per condition
		Returns: None
		"""
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            # raise ValueError("Verbose absent")
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        print("any sample in _take_policy_samples?")
        for cond in range(len(self._test_idx)):
            # for different cond
            # original code
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
            print(pol_samples[cond][0])
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
		Log data and algorithm, and update the GUI.
		Args:
			itr: Iteration number.
			traj_sample_lists: trajectory samples as SampleList object
			pol_sample_lists: policy samples as SampleList object
		Returns: None
		"""
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
        else:
            return
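Example #18 keeps only the fastest successful rollout: whenever the agent reports a finishing_time smaller than any seen so far, the sample's controls and states are sliced up to that timestep. A small worked illustration of that bookkeeping (with random arrays standing in for get_U()/get_X()) is:

import numpy as np

T, dX, dU = 100, 7, 2                                # horizon and dimensions (placeholders)
rollouts = [
    {'finishing_time': 57, 'X': np.random.randn(T, dX), 'U': np.random.randn(T, dU)},
    {'finishing_time': 42, 'X': np.random.randn(T, dX), 'U': np.random.randn(T, dU)},
]

best_t, best_U, final_pos = None, None, None
for r in rollouts:
    t = r['finishing_time']
    if t is not None and (best_t is None or t < best_t):
        best_t = t
        best_U = r['U'][:t]                          # controls applied before the goal was reached
        final_pos = r['X'][t]                        # state at the finishing timestep
print(best_t, best_U.shape)                          # 42 (42, 2)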
Example #19
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        # This controls whether there is a learned reset or not
        # There is a learned reset when special_reset = True
        try:
            self.special_reset = config['agent']['special_reset']
        except KeyError:
            self.special_reset = False
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        # There will be as many reset conditions as training conditions
        self._reset_conditions = self._conditions
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        import tensorflow as tf
        try:
            with tf.variable_scope(
                    'reset'):  # to avoid variable naming conflicts
                # build a second algorithm instance for the learned reset
                self.reset_algorithm = config['reset_algorithm']['type'](
                    config['algorithm'])
        except Exception:
            # no usable 'reset_algorithm' entry in the config; skip the learned reset
            pass

        self.saved_algorithm = copy.deepcopy(
            self.algorithm)  # Save this newly initialized alg
        # If you want to warm start the algorithm BADMM with learned iLQG controllers
        self.diff_warm_start = True
        # If you want to warm start the neural network with some pretrained network
        self.nn_warm_start = False
        # The following variables determine the pretrained network path
        # Change these if you want to take the other policy type
        attention, structure = 'time', 'mlp'
        self.policy_path = os.path.join(
            self._data_files_dir,
            os.path.join('{}_{}'.format(attention, structure), 'policy'))
        try:
            self.old_policy_opt = copy.deepcopy(self.algorithm.policy_opt)
        except Exception:
            pass  # fall back silently if policy_opt cannot be copied

        pdb.set_trace()
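Example #19 wraps the reset algorithm's construction in tf.variable_scope('reset') so that the second set of network variables does not collide with the main policy's. A minimal TensorFlow 1.x illustration of that isolation (standalone, not from the example) is:

import tensorflow as tf  # TensorFlow 1.x API, matching the example above

with tf.variable_scope('reset'):
    w_reset = tf.get_variable('w', shape=[3, 3])   # created as 'reset/w:0'
w_main = tf.get_variable('w', shape=[3, 3])        # created as 'w:0' -- no name clash
print(w_reset.name, w_main.name)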
Example #20
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config):
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)
        # --- to execute without training (added by Brook) ---
        self.executable = False
        if self.executable:
            while (self.executable):
                self.agent.execute(self.algorithm.policy_opt.policy)
        else:
            for itr in range(itr_start, self._hyperparams['iterations']):
                for cond in range(self._conditions):
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in range(self._conditions)
                ]
                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)


        # --- end of execute-without-training block ---
        self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                pol_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('pol_sample_itr_%02d.pkl' % itr_load))
                self.gui.update(itr_load, self.algorithm, self.agent,
                                traj_sample_lists, pol_sample_lists)
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            return None
        if not N:
            N = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None for _ in range(N)]
                       for _ in range(self._conditions)]
        for cond in range(self._conditions):
            for i in range(N):
                pol_samples[cond][i] = self.agent.sample(
                    self.algorithm.policy_opt.policy,
                    cond,
                    verbose=True,
                    save=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
Example #21
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

        self.init_alpha(self)

    def run(self, config, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        self.target_points = self.agent._hyperparams['target_ee_points'][:3]
        itr_start = self._initialize(itr_load)

        # """ set pre"""
        # position_train = self.data_logger.unpickle('./position/position_train.pkl')
        """ generate random training position in a specify circle"""
        center_position = np.array([0.05, -0.08, 0])
        position_train = self.generate_position_radius(center_position, 0.08,
                                                       7, 0.02)

        print('training position.....')
        print(position_train)

        # print('test all testing position....')
        # for i in xrange(position_train.shape[0]):
        #     test_positions = self.generate_position_radius(position_train[i], 0.03, 5, 0.01)
        #     if i == 0:
        #         all_test_positions = test_positions
        #     else:
        #         all_test_positions = np.concatenate((all_test_positions, test_positions))

        T = self.algorithm.T
        N = self._hyperparams['num_samples']
        dU = self.algorithm.dU
        for num_pos in range(position_train.shape[0]):
            """ load train position and reset agent model. """
            for cond in self._train_idx:
                self._hyperparams['agent']['pos_body_offset'][
                    cond] = position_train[num_pos]
            self.agent.reset_model(self._hyperparams)

            # initialize training arrays
            train_prc = np.zeros((0, T, dU, dU))
            train_mu = np.zeros((0, T, dU))
            train_obs_data = np.zeros((0, T, self.algorithm.dO))
            train_wt = np.zeros((0, T))
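            # Added note on shapes: precision matrices are (N, T, dU, dU), action
            # means (N, T, dU), observations (N, T, dO) and weights (N, T); each
            # successful rollout appends along the first axis via np.concatenate below.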

            # initialize counters
            count_suc = 0

            for itr in range(itr_start, self._hyperparams['iterations']):
                print('******************num_pos:************', num_pos)
                print('______________________itr:____________', itr)
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        if num_pos == 0:
                            self._take_sample(itr, cond, i)
                        elif itr == 0:
                            self._take_sample(itr, cond, i)
                        else:
                            self._take_train_sample(itr, cond, i)
                            # self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # calculate the distance from the end-effector to the target position
                ee_pos = self.agent.get_ee_pos(cond)[:3]
                target_pos = self.agent._hyperparams['target_ee_pos'][:3]
                distance_pos = ee_pos - target_pos
                distance_ee = np.sqrt(distance_pos.dot(distance_pos))
                print('distance ee:', distance_ee)

                # collect the successful samples to train the global policy
                if distance_ee <= 0.06:
                    count_suc += 1
                    tgt_mu, tgt_prc, obs_data, tgt_wt = self.train_prepare(
                        traj_sample_lists)
                    train_mu = np.concatenate((train_mu, tgt_mu))
                    train_prc = np.concatenate((train_prc, tgt_prc))
                    train_obs_data = np.concatenate((train_obs_data, obs_data))
                    train_wt = np.concatenate((train_wt, tgt_wt))

                # Clear agent samples.
                self.agent.clear_samples()

                # stop once enough successful samples have been collected
                if count_suc > 8:
                    break

                self._take_iteration(itr, traj_sample_lists)
                if self.algorithm.flag_reset:
                    break
                # pol_sample_lists = self._take_policy_samples()
                # self._log_data(itr, traj_sample_lists, pol_sample_lists)
                if num_pos > 0:
                    self.algorithm.fit_global_linear_policy(traj_sample_lists)

            if not self.algorithm.flag_reset:
                # train NN with good samples
                self.algorithm.policy_opt.update(train_obs_data, train_mu,
                                                 train_prc, train_wt)

                # test the trained policy at the current position
                print('test current policy.....')
                self.test_current_policy()
                print('test all testing position....')
                for i in xrange(position_train.shape[0]):
                    test_positions = self.generate_position_radius(
                        position_train[i], 0.03, 5, 0.01)
                    if i == 0:
                        all_test_positions = test_positions
                    else:
                        all_test_positions = np.concatenate(
                            (all_test_positions, test_positions))
                self.test_cost(all_test_positions)

            # reset the algorithm to its initial state for the next position
            # del self.algorithm
            # config['algorithm']['agent'] = self.agent
            # self.algorithm = config['algorithm']['type'](config['algorithm'])
            self.algorithm.reset_alg()
            self.next_iteration_prepare()

        self._end()

    def generate_position_radius(self, position_ori, radius, conditions,
                                 max_error_bound):
        """

        Args:
            position_ori: original center position of generated positions
            radius:     area's radius
            conditions: the quantity of generating positions
            max_error_bound: the mean of generated positions' error around cposition

        Returns:

        """
        c_x = position_ori[0]
        c_y = position_ori[1]
        while True:
            all_positions = np.zeros(0)
            center_position = np.array([c_x, c_y, 0])
            for i in range(conditions):
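                # rejection sampling: redraw the offset until the shifted point
                # falls inside the accepted region around the center position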
                position = np.random.uniform(-radius, radius, 3)
                while True:
                    position[2] = 0
                    position[1] += c_y
                    position[0] += c_x
                    area = (position - center_position).dot(position -
                                                            center_position)
                    if area <= (np.pi * radius**2) / 4.0:
                        break
                    position = np.random.uniform(-radius, radius, 3)
                if i == 0:
                    all_positions = position
                    all_positions = np.expand_dims(all_positions, axis=0)
                else:
                    all_positions = np.vstack((all_positions, position))

            mean_position = np.mean(all_positions, axis=0)
            mean_error = np.fabs(center_position - mean_position)
            print('mean_error:', mean_error)
            if mean_error[0] < max_error_bound and mean_error[
                    1] < max_error_bound:
                break

        all_positions = np.floor(all_positions * 1000) / 1000.0
        print('all_position:', all_positions)
        return all_positions

    def test_cost(self, position):
        """
        test the NN policy at all position
        Args:
            position:

        Returns:

        """
        total_costs = np.zeros(0)
        total_distance = np.zeros(0)
        total_suc = np.zeros(0)
        print('calculate cost_________________')
        for itr in range(position.shape[0]):
            if itr % 51 == 0:
                print('****************')
            for cond in self._train_idx:
                self._hyperparams['agent']['pos_body_offset'][cond] = position[
                    itr]
            self.agent.reset_model(self._hyperparams)
            _, cost, ee_points = self.take_nn_samples()
            ee_error = ee_points[:3] - self.target_points
            distance = np.sqrt(ee_error.dot(ee_error))
            error = np.sum(np.fabs(ee_error))
            if (error < 0.02):
                total_suc = np.concatenate((total_suc, np.array([1])))
            else:
                total_suc = np.concatenate((total_suc, np.array([0])))
            total_costs = np.concatenate((total_costs, np.array(cost)))
            total_distance = np.concatenate(
                (total_distance, np.array([distance])))
        # return np.mean(total_costs), total_suc, total_distance
        return total_costs, total_suc, total_distance

    def next_iteration_prepare(self):
        """
        prepare for the next iteration
        Returns:

        """
        self.init_alpha()

    def init_alpha(self, val=None):
        """
        initialize the alpha1, 2, the default is 0.7, 0.3
        Args:
            val:

        Returns:

        """
        if val is None:
            self.alpha1 = 0.75
            self.alpha2 = 0.25
        else:
            self.alpha1 = 0.75
            self.alpha2 = 0.25

    def pol_alpha(self):
        return self.alpha1, self.alpha2

    def train_prepare(self, sample_lists):
        """
        prepare the train data of the sample lists
        Args:
            sample_lists: sample list from agent

        Returns:
            target mu, prc, obs_data, wt

        """
        algorithm = self.algorithm
        dU, dO, T = algorithm.dU, algorithm.dO, algorithm.T
        obs_data, tgt_mu = np.zeros((0, T, dO)), np.zeros((0, T, dU))
        tgt_prc = np.zeros((0, T, dU, dU))
        tgt_wt = np.zeros((0, T))
        wt_origin = 0.01 * np.ones(T)
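        # constant per-time-step weight applied to every training sample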
        for m in range(algorithm.M):
            samples = sample_lists[m]
            X = samples.get_X()
            N = len(samples)
            prc = np.zeros((N, T, dU, dU))
            mu = np.zeros((N, T, dU))
            wt = np.zeros((N, T))

            traj = algorithm.cur[m].traj_distr
            for t in range(T):
                prc[:, t, :, :] = np.tile(traj.inv_pol_covar[t, :, :],
                                          [N, 1, 1])
                for i in range(N):
                    mu[i, t, :] = traj.K[t, :, :].dot(X[i, t, :]) + traj.k[t, :]
                wt[:, t].fill(wt_origin[t])
            tgt_mu = np.concatenate((tgt_mu, mu))
            tgt_prc = np.concatenate((tgt_prc, prc))
            obs_data = np.concatenate((obs_data, samples.get_obs()))
            tgt_wt = np.concatenate((tgt_wt, wt))

        return tgt_mu, tgt_prc, obs_data, tgt_wt

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_train_sample(self, itr, cond, i):
        """
        collect sample with merge policy
        Args:
            itr:
            cond:
            i:

        Returns:

        """
        alpha1, alpha2 = self.pol_alpha()
        print("alpha:********%03f, %03f******" % (alpha1, alpha2))
        pol1 = self.algorithm.cur[cond].traj_distr
        pol2 = self.algorithm.cur[cond].last_pol
        if not self.gui:
            self.agent.merge_controller(
                pol1,
                alpha1,
                pol2,
                alpha2,
                cond,
                verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def take_nn_samples(self, N=None):
        """
        take the NN policy
        Args:
            N:

        Returns:
            samples, costs, ee_points

        """
        """
            Take samples from the policy to see how it's doing.
            Args:
                N  : number of policy samples to take per condition
            Returns: None
            """

        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        costs = list()
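        # roll out the NN policy once per test condition, recording its cost
        # and the final end-effector points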
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
            policy_cost = self.algorithm.cost[0].eval(pol_samples[cond][0])[0]
            policy_cost = np.sum(policy_cost)
            print "cost: %d" % policy_cost  # wait to plot in gui in gps_training_gui.py
            costs.append(policy_cost)

            ee_points = self.agent.get_ee_point(cond)

        return [SampleList(samples)
                for samples in pol_samples], costs, ee_points

    def test_current_policy(self):
        """
        test the current NN policy in the current position
        Returns:

        """
        verbose = self._hyperparams['verbose_policy_trials']
        for cond in self._train_idx:
            samples = self.agent.sample(self.algorithm.policy_opt.policy,
                                        cond,
                                        verbose=verbose,
                                        save=False,
                                        noisy=False)

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
Example #22
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common']['conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(config['common']) if config['gui_on'] else None
        
        #self.use_prev_dyn = True # Modified dynamics; Alisher ,
        config['algorithm']['agent'] = self.agent
        print '__init__ is called ##########'
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            itr_start = self._initialize(itr_load)

            self.agent.open_camera()
            for itr in range(itr_start, self._hyperparams['iterations']):
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)
                    print('Current condition: {0}/{1}'.format(cond,self._train_idx))
                    raw_input('Press Enter to continue next condition...')

                traj_sample_lists = [
                    self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]

                # Clear agent samples.
                self.agent.clear_samples()

                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples(itr)
                self._log_data(itr, traj_sample_lists, pol_sample_lists)
        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        itr = 10  # hard-coded: always load the algorithm state from iteration 10
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        ##
        self.algorithm.policy_opt.main_itr = itr
        ##
        traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
            ('traj_sample_itr_%02d.pkl' % itr))

        raw_input('Press Enter to continue...')
        pol_sample_lists = self._take_policy_samples(itr, N)
        '''
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists)
        )
        '''
        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.set_status_text(('Took %d policy sample(s) from ' +
                'algorithm state at iteration %d.\n' +
                'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') % (N, itr, itr))
        os._exit(1)

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(1) # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                     ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text(
                    ('Resuming training from algorithm state at iteration %d.\n' +
                    'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        ##
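        # record the current outer iteration on the policy optimizer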
        self.algorithm.policy_opt.main_itr = itr
        ##
        if self.gui:
            self.gui.set_image_overlays(cond)   # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i)
                )
                print '\nCondition: %d, sample: %d' % (cond, i)
                self.agent.sample(itr,
                    pol, cond,
                    verbose=(i < self._hyperparams['verbose_trials'])
                )

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(itr, 
                pol, cond,
                verbose=(i < self._hyperparams['verbose_trials'])
            )

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(itr, sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, itr, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            itr: iteration number, passed through to the agent; can be removed when not needed.
            N  : number of policy samples to take per condition
        Returns: None
        """
        
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        
        '''
        self.agent.init_gripper()
        self.agent.calibrate_gripper()
        self.agent.open_gripper()
        '''

        if N is None:
            N = 1
        for num_itr in range(N):
            for cond in range(len(self._test_idx)):
                pol_samples[cond][0] = self.agent.sample(
                    itr, self.algorithm.policy_opt.policy, self._test_idx[cond],
                    verbose=verbose, save=False, noisy=False)
                '''
                self.agent.close_gripper()
                time.sleep(1)
                self.agent.move_to_position([0.7811797153381348, -0.7014127144592286, 0.2304806131164551, 1.2939127931030274, 0.04908738515625, 0.7696748594421388, 0.32021848910522466])
                time.sleep(1)
                self.agent.move_to_position([0.6906748489562988, -0.33555829696655276, 0.10392719826049805, 1.2417574463745118, -0.050237870745849615, 0.5273058952331543, 0.06327670742797852])
                time.sleep(1)
                self.agent.open_gripper()
                '''
                print('Current condition: {0}/{1}'.format(cond,self._train_idx))
                raw_input('Press Enter to continue next condition...')
        
        ''' kiro project assembly
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]

        self.agent.init_gripper()
        self.agent.calibrate_gripper()
        self.agent.open_gripper()

        right_put = [-0.14879613625488283, -0.487038899597168, 0.9096506061767579, 1.1926700612182617, -0.8218302061706544, 1.2524953118774416, 1.1355292769348144]

        if type(N) == type(None):
            N = 3
        t1 = time.time()
        for num_itr in range(N):
            t3 = time.time()
            self.agent.move_left_to_position([-0.5223204576782227, 0.5016117170654297, -2.212767283996582, -0.051771851531982424, 0.4375680192443848, 2.0432624071289065, 0.20286895896606447])
            self.agent.move_to_position(right_put)
            raw_input('Press enter to start and close gripper...')
            print 'sample number : %d' % num_itr
            self.agent.close_gripper()
            time.sleep(1)        
            for cond in range(len(self._test_idx)):
                pol_samples[cond][0] = self.agent.sample(
                    itr, self.algorithm.policy_opt.policy, self._test_idx[cond],
                    verbose=verbose, save=False, noisy=False)

                self.agent.move_to_position([0.11965050131835939, 0.09549030393676758, 1.606461378277588, 1.2137622970275879, -0.5457136646667481, 1.030451593084717, 0.1721893432434082])
                self.agent.open_gripper()
                time.sleep(0.5)

                # self.agent.move_to_position([0.16375244891967775, -0.17832526638793947, 1.262849682183838, 1.0630486847900391, -0.012655341485595705, 1.16850986383667, 0.2174417764343262])
                self.agent.move_to_position([-0.10239321747436524, -0.28723790220336914, 1.2858593939758303, 0.9840486743041993, 0.051004861138916016, 1.502917675213623, 0.15531555459594729])
                self.agent.move_left_to_position([-0.5223204576782227, 0.5016117170654297, -2.212767283996582, -0.051771851531982424, 0.4375680192443848, 2.0432624071289065, 0.20286895896606447])
                # self.agent.move_left_to_position(self._hyperparams['initial_left_arm'])
            t4 = time.time()
            print '%d m %d s' % ((t4-t3)/60, (t4-t3)%60)
        t2 = time.time()
        print '%d h %d m %d s' % ((t2 - t1)/3600, (t2-t1)/60, (t2-t1)%60)
        return [SampleList(samples) for samples in pol_samples]
        '''

        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.save_figure(
                self._data_files_dir + ('figure_itr_%02d.png' % itr)
            )
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        '''
        if ((itr+1) % 5 == 0 or (itr+1) % 3 == 0):
            self.data_logger.pickle(
                self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
                copy.copy(self.algorithm)
            )
            self.data_logger.pickle(
                self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
                copy.copy(traj_sample_lists)
            )
            if pol_sample_lists:
                self.data_logger.pickle(
                    self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                    copy.copy(pol_sample_lists)
                )
        '''
        save = raw_input('Do you want to save pickle file of policy at itr %02d? [y/n]: ' % itr)
        if save == 'y':
            self.data_logger.pickle(
                self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
                copy.copy(self.algorithm)
            )
            self.data_logger.pickle(
                self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
                copy.copy(traj_sample_lists)
            )
            if pol_sample_lists:
                self.data_logger.pickle(
                    self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                    copy.copy(pol_sample_lists)
                )

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)

    def save_local_controller(self):
        itr = 10
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        
        ## save
        for m in range(2):
            self.algorithm.cur[m].pol_info = None
            self.algorithm.cur[m].sample_list = None
            self.algorithm.cur[m].new_traj_distr = None

        path = '/hdd/gps-master/experiments/prepare_controller/data_files/'
        self.data_logger.pickle(path + 'dynamcis.pkl', copy.copy(self.algorithm.cur))
        os._exit(1)
Example #23
class GPSMain(object):
    """ Main class to run tensorflow_code-pytorch and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """

        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']

        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])

        self.data_logger = DataLogger()

        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            itr_start = self._initialize(itr_load)
            for itr in range(itr_start, self._hyperparams['iterations']):
                """ get samples """
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]
                """ Clear agent samples """
                self.agent.clear_samples()
                """ interation """
                self._take_iteration(itr, traj_sample_lists)
                """ test policy and samples """
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)

        except Exception as e:
            traceback.print_exception(*sys.exc_info())

        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)

        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)
            # called instead of sys.exit(), since t
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)

        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(
                    1
                )  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
            print(" ========================== on policy ====================")
        else:
            pol = self.algorithm.cur[cond].traj_distr

        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()

        self.algorithm.iteration(sample_lists)

        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        print(
            " ================================ test policy ===================================="
        )
        if 'verbose_policy_trials' not in self._hyperparams:
            # AlgorithmTrajOpt
            return None
        verbose = self._hyperparams['verbose_policy_trials']

        if self.gui:
            self.gui.set_status_text('Taking policy samples.')

        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)

        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
Example #24
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        if type(self.agent) == AgentSUPERball:
            self.agent.reset()
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)

        for itr in range(itr_start, self._hyperparams['iterations']):
            for i in range(self._hyperparams['num_samples']):
                for cond in self._train_idx:
                    self._take_sample(itr, cond, i)

            traj_sample_lists = [
                self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]

            self._take_iteration(itr, traj_sample_lists)

            pol_sample_lists = None
            if ('verbose_policy_trials' in self._hyperparams
                    and self._hyperparams['verbose_policy_trials']):
                pol_sample_lists = self._take_policy_samples()
            self._log_data(itr, traj_sample_lists, pol_sample_lists)

            if 'save_controller' in self._hyperparams and self._hyperparams[
                    'save_controller']:
                self._save_superball_controllers()

        self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N, test_policy=True)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(1)

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        extra_args = {}
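        # AgentSUPERball takes per-condition sampling parameters (reset/relax
        # behaviour, bottom face, initial motor positions, control gain)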
        if type(self.agent) == AgentSUPERball:
            extra_args = {
                'superball_parameters': {
                    'reset': ('reset' not in self._hyperparams['agent']
                              or self._hyperparams['agent']['reset'][cond]),
                    'relax': ('relax' in self._hyperparams['agent']
                              and self._hyperparams['agent']['relax'][cond]),
                    'bottom_face':
                    (None if 'bottom_faces' not in self._hyperparams['agent']
                     else self._hyperparams['agent']['bottom_faces'][cond]),
                    'start_motor_positions':
                    (None if 'start_motor_positions'
                     not in self._hyperparams['agent'] else
                     self._hyperparams['agent']['start_motor_positions'][cond]
                     ),
                    'motor_position_control_gain':
                    (None if 'motor_position_control_gain'
                     not in self._hyperparams['agent'] else
                     self._hyperparams['agent']['motor_position_control_gain']
                     [cond]),
                }
            }

        if self.algorithm._hyperparams['sample_on_policy'] \
                and (self.algorithm.iteration_count > 0
                     or ('sample_pol_first_itr' in self.algorithm._hyperparams
                         and self.algorithm._hyperparams['sample_pol_first_itr'])):
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr

        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']),
                    **extra_args)

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol,
                cond,
                verbose=(i < self._hyperparams['verbose_trials']),
                **extra_args)

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None, test_policy=False):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: None
        """
        if not N:
            N = self._hyperparams['verbose_policy_trials']
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        for cond in range(len(self._test_idx)):
            extra_args = {}
            if type(self.agent) == AgentSUPERball:
                extra_args = {
                    'superball_parameters': {
                        'reset':
                        ('reset' not in self._hyperparams['agent']
                         or self._hyperparams['agent']['reset'][cond]),
                        'relax':
                        ('relax' in self._hyperparams['agent']
                         and self._hyperparams['agent']['relax'][cond]),
                        'bottom_face':
                        (None if 'bottom_faces'
                         not in self._hyperparams['agent'] else
                         self._hyperparams['agent']['bottom_faces'][cond]),
                        'horizon':
                        (None if (not test_policy) or 'policy_test_horizon'
                         not in self._hyperparams['agent'] else
                         self._hyperparams['agent']['policy_test_horizon']),
                        'start_motor_positions':
                        (None if 'start_motor_positions'
                         not in self._hyperparams['agent'] else self.
                         _hyperparams['agent']['start_motor_positions'][cond]),
                        'motor_position_control_gain':
                        (None if 'motor_position_control_gain'
                         not in self._hyperparams['agent'] else
                         self._hyperparams['agent']
                         ['motor_position_control_gain'][cond]),
                        'debug':
                        False,
                    }
                }
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False,
                **extra_args)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()

        if self._quit_on_end:
            # Quit automatically (for running sequential expts)
            os._exit(1)

    def _save_superball_controllers(self):
        from scipy import io
        try:
            controllers = io.loadmat('init_controllers.mat')
        except IOError:
            controllers = {}
        for m in range(self._conditions):
            t = self.algorithm.cur[m].traj_distr
            if 'save' in self._hyperparams['algorithm']['init_traj_distr']:
                s = self._hyperparams['algorithm']['init_traj_distr']['save'][
                    m]
            else:
                s = str(m)
            controllers.update({
                ('K' + s): t.K,
                ('k' + s): t.k,
                ('PS' + s): t.pol_covar,
                ('cPS' + s): t.chol_pol_covar,
                ('iPS' + s): t.inv_pol_covar
            })
        io.savemat(
            self._hyperparams['common']['experiment_dir'] +
            'init_controllers.mat', controllers)
        io.savemat('init_controllers.mat', controllers)
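
For completeness, a minimal sketch (not part of the file above) of reading the saved controllers back; it assumes scipy is available and that the key suffix matches the 'save' label (or condition index) used by _save_superball_controllers.

from scipy import io

# Keys follow the 'K<s>' / 'k<s>' / 'PS<s>' / 'cPS<s>' / 'iPS<s>' convention
# written by _save_superball_controllers above.
controllers = io.loadmat('init_controllers.mat')

s = '0'  # hypothetical suffix: a condition index or an init_traj_distr 'save' label
K = controllers['K' + s]                  # time-varying feedback gains
k = controllers['k' + s]                  # open-loop terms
pol_covar = controllers['PS' + s]         # policy covariance
chol_pol_covar = controllers['cPS' + s]   # Cholesky factor of the covariance
inv_pol_covar = controllers['iPS' + s]    # inverse covariance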
Example #25
File: gps_main.py Project: Ivehui/gps
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config):
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common']['conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx

        self._data_files_dir = config['common']['data_files_dir']

        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(config['common']) if config['gui_on'] else None

        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        itr_start = self._initialize(itr_load)

        for itr in range(itr_start, self._hyperparams['iterations']):
            for cond in self._train_idx:
                for i in range(self._hyperparams['num_samples']):
                    self._take_sample(itr, cond, i)

            traj_sample_lists = [
                self.agent.get_samples(cond, -self._hyperparams['num_samples'])
                for cond in self._train_idx
            ]
            self._take_iteration(itr, traj_sample_lists)
            pol_sample_lists = self._take_policy_samples()
            self._log_data(itr, traj_sample_lists, pol_sample_lists)

        self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1) # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
            ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists)
        )

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.set_status_text(('Took %d policy sample(s) from ' +
                'algorithm state at iteration %d.\n' +
                'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') % (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(1) # called instead of sys.exit(), since this is in a thread
                
            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                pol_sample_lists = self.data_logger.unpickle(self._data_files_dir +
                    ('pol_sample_itr_%02d.pkl' % itr_load))
                self.gui.update(itr_load, self.algorithm, self.agent,
                    traj_sample_lists, pol_sample_lists)
                self.gui.set_status_text(
                    ('Resuming training from algorithm state at iteration %d.\n' +
                    'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)   # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i)
                )
                self.agent.sample(
                    pol, cond,
                    verbose=(i < self._hyperparams['verbose_trials'])
                )

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond,
                verbose=(i < self._hyperparams['verbose_trials'])
            )

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: A list of SampleList objects (one per condition), or None.
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            return None
        if not N:
            N = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None for _ in range(N)] for _ in range(self._conditions)]
        for cond in range(len(self._test_idx)):
            for i in range(N):
                pol_samples[cond][i] = self.agent.sample(
                    self.algorithm.policy_opt.policy, self._test_idx[cond],
                    verbose=True, save=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent,
                traj_sample_lists, pol_sample_lists)
            self.gui.save_figure(
                self._data_files_dir + ('figure_itr_%02d.png' % itr)
            )
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm)
        )
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists)
        )
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists)
            )

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
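
A minimal usage sketch for this class (not part of the original file): it assumes a hyperparams.py module defining a module-level config dict of the shape GPSMain expects, and that GPSMain is importable (e.g. from gps.gps_main).

import imp

# Hypothetical experiment path; real experiments define a module-level
# `config` dict in hyperparams.py.
hyperparams = imp.load_source(
    'hyperparams', 'experiments/my_experiment/hyperparams.py')

gps = GPSMain(hyperparams.config)
gps.run()                      # sample, iterate and log for config['iterations'] iterations
# gps.test_policy(itr=4, N=5)  # optionally re-evaluate the policy saved at iteration 4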
Example #26
class GPSMain(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, quit_on_end=False):
        """
        Initialize GPSMain
        Args:
            config: Hyperparameters for experiment
            quit_on_end: When true, quit automatically on completion
        """
        self._quit_on_end = quit_on_end
        self._hyperparams = config
        self._conditions = config['common']['conditions']
        if 'train_conditions' in config['common']:
            self._train_idx = config['common']['train_conditions']
            self._test_idx = config['common']['test_conditions']
        else:
            # No explicit train conditions given: train and test on all
            # conditions, and store the updated config back as the hyperparams.
            self._train_idx = range(self._conditions)
            config['common']['train_conditions'] = config['common'][
                'conditions']
            self._hyperparams = config
            self._test_idx = self._train_idx
        # Directory where per-iteration pickles and figures are written.
        self._data_files_dir = config['common']['data_files_dir']
        # Instantiate the agent class named in the config (e.g. AgentBox2D).
        self.agent = config['agent']['type'](config['agent'])
        self.data_logger = DataLogger()
        self.gui = GPSTrainingGUI(
            config['common']) if config['gui_on'] else None
        # The algorithm (e.g. gps.algorithm.algorithm_traj_opt.AlgorithmTrajOpt)
        # receives a reference to the agent through the config.
        config['algorithm']['agent'] = self.agent
        self.algorithm = config['algorithm']['type'](config['algorithm'])

    def run(self, itr_load=None):
        """
        Run training by iteratively sampling and taking an iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns: None
        """
        try:
            # _initialize opens the GUI (if enabled) and returns the iteration
            # to resume from (0 for a fresh run, itr_load + 1 when resuming).
            itr_start = self._initialize(itr_load)
            for itr in range(itr_start, self._hyperparams['iterations']):
                # Collect num_samples trajectory samples per training condition.
                for cond in self._train_idx:
                    for i in range(self._hyperparams['num_samples']):
                        self._take_sample(itr, cond, i)

                # get_samples (defined on the Agent superclass) returns the
                # most recent num_samples samples for each condition.
                traj_sample_lists = [
                    self.agent.get_samples(cond,
                                           -self._hyperparams['num_samples'])
                    for cond in self._train_idx
                ]
                # Clear agent samples.
                self.agent.clear_samples()

                self._take_iteration(itr, traj_sample_lists)
                pol_sample_lists = self._take_policy_samples()
                self._log_data(itr, traj_sample_lists, pol_sample_lists)
        except Exception as e:
            traceback.print_exception(*sys.exc_info())
        finally:
            self._end()

    def test_policy(self, itr, N):
        """
        Take N policy samples of the algorithm state at iteration itr,
        for testing the policy to see how it is behaving.
        (Called directly from the command line --policy flag).
        Args:
            itr: the iteration from which to take policy samples
            N: the number of policy samples to take
        Returns: None
        """
        algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr
        self.algorithm = self.data_logger.unpickle(algorithm_file)
        if self.algorithm is None:
            print("Error: cannot find '%s.'" % algorithm_file)
            os._exit(1)  # called instead of sys.exit(), since this is in a thread
        traj_sample_lists = self.data_logger.unpickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr))

        pol_sample_lists = self._take_policy_samples(N)
        self.data_logger.pickle(
            self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
            copy.copy(pol_sample_lists))

        if self.gui:
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.set_status_text(
                ('Took %d policy sample(s) from ' +
                 'algorithm state at iteration %d.\n' +
                 'Saved to: data_files/pol_sample_itr_%02d.pkl.\n') %
                (N, itr, itr))

    def _initialize(self, itr_load):
        """
        Initialize from the specified iteration.
        Args:
            itr_load: If specified, loads algorithm state from that
                iteration, and resumes training at the next iteration.
        Returns:
            itr_start: Iteration to start from.
        """
        if itr_load is None:
            if self.gui:
                self.gui.set_status_text('Press \'go\' to begin.')
            return 0
        else:
            algorithm_file = self._data_files_dir + 'algorithm_itr_%02d.pkl' % itr_load
            self.algorithm = self.data_logger.unpickle(algorithm_file)
            if self.algorithm is None:
                print("Error: cannot find '%s.'" % algorithm_file)
                os._exit(1)  # called instead of sys.exit(), since this is in a thread

            if self.gui:
                traj_sample_lists = self.data_logger.unpickle(
                    self._data_files_dir +
                    ('traj_sample_itr_%02d.pkl' % itr_load))
                if self.algorithm.cur[0].pol_info:
                    pol_sample_lists = self.data_logger.unpickle(
                        self._data_files_dir +
                        ('pol_sample_itr_%02d.pkl' % itr_load))
                else:
                    pol_sample_lists = None
                self.gui.set_status_text((
                    'Resuming training from algorithm state at iteration %d.\n'
                    + 'Press \'go\' to begin.') % itr_load)
            return itr_load + 1

    def _take_sample(self, itr, cond, i):
        """
        Collect a sample from the agent.
        Args:
            itr: Iteration number.
            cond: Condition number.
            i: Sample number.
        Returns: None
        """
        # Sample on-policy after the first iteration when configured to do so;
        # otherwise sample from the current local trajectory distribution.
        if self.algorithm._hyperparams['sample_on_policy'] \
                and self.algorithm.iteration_count > 0:
            pol = self.algorithm.policy_opt.policy
        else:
            pol = self.algorithm.cur[cond].traj_distr
        if self.gui:
            self.gui.set_image_overlays(cond)  # Must call for each new cond.
            redo = True
            while redo:
                while self.gui.mode in ('wait', 'request', 'process'):
                    if self.gui.mode in ('wait', 'process'):
                        time.sleep(0.01)
                        continue
                    # 'request' mode.
                    if self.gui.request == 'reset':
                        try:
                            self.agent.reset(cond)
                        except NotImplementedError:
                            self.gui.err_msg = 'Agent reset unimplemented.'
                    elif self.gui.request == 'fail':
                        self.gui.err_msg = 'Cannot fail before sampling.'
                    self.gui.process_mode()  # Complete request.

                self.gui.set_status_text(
                    'Sampling: iteration %d, condition %d, sample %d.' %
                    (itr, cond, i))
                # Sampling is performed by the agent (here an AgentBox2D
                # instance); the generic sample logic lives in the Agent superclass.
                self.agent.sample(
                    pol,
                    cond,
                    verbose=(i < self._hyperparams['verbose_trials']))

                if self.gui.mode == 'request' and self.gui.request == 'fail':
                    redo = True
                    self.gui.process_mode()
                    self.agent.delete_last_sample(cond)
                else:
                    redo = False
        else:
            self.agent.sample(
                pol, cond, verbose=(i < self._hyperparams['verbose_trials']))

    def _take_iteration(self, itr, sample_lists):
        """
        Take an iteration of the algorithm.
        Args:
            itr: Iteration number.
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Calculating.')
            self.gui.start_display_calculating()
        # iteration() is defined on the Algorithm parent class in the gps
        # directory; sample_lists holds the trajectory samples collected above.
        self.algorithm.iteration(sample_lists)
        if self.gui:
            self.gui.stop_display_calculating()

    def _take_policy_samples(self, N=None):
        """
        Take samples from the policy to see how it's doing.
        Args:
            N  : number of policy samples to take per condition
        Returns: A list of SampleList objects (one per test condition), or None.
        """
        if 'verbose_policy_trials' not in self._hyperparams:
            # Trajectory-optimization-only runs (e.g. AlgorithmTrajOpt) have no
            # global policy to sample, so skip policy samples.
            return None
        verbose = self._hyperparams['verbose_policy_trials']
        if self.gui:
            self.gui.set_status_text('Taking policy samples.')
        pol_samples = [[None] for _ in range(len(self._test_idx))]
        # Since this isn't noisy, just take one sample.
        # TODO: Make this noisy? Add hyperparam?
        # TODO: Take at all conditions for GUI?
        for cond in range(len(self._test_idx)):
            pol_samples[cond][0] = self.agent.sample(
                self.algorithm.policy_opt.policy,
                self._test_idx[cond],
                verbose=verbose,
                save=False,
                noisy=False)
        return [SampleList(samples) for samples in pol_samples]

    def _log_data(self, itr, traj_sample_lists, pol_sample_lists=None):
        """
        Log data and algorithm, and update the GUI.
        Args:
            itr: Iteration number.
            traj_sample_lists: trajectory samples as SampleList object
            pol_sample_lists: policy samples as SampleList object
        Returns: None
        """
        if self.gui:
            self.gui.set_status_text('Logging data and updating GUI.')
            self.gui.update(itr, self.algorithm, self.agent, traj_sample_lists,
                            pol_sample_lists)
            self.gui.save_figure(self._data_files_dir +
                                 ('figure_itr_%02d.png' % itr))
        if 'no_sample_logging' in self._hyperparams['common']:
            return
        self.data_logger.pickle(
            self._data_files_dir + ('algorithm_itr_%02d.pkl' % itr),
            copy.copy(self.algorithm))
        self.data_logger.pickle(
            self._data_files_dir + ('traj_sample_itr_%02d.pkl' % itr),
            copy.copy(traj_sample_lists))
        if pol_sample_lists:
            self.data_logger.pickle(
                self._data_files_dir + ('pol_sample_itr_%02d.pkl' % itr),
                copy.copy(pol_sample_lists))

    def _end(self):
        """ Finish running and exit. """
        if self.gui:
            self.gui.set_status_text('Training complete.')
            self.gui.end_mode()
            if self._quit_on_end:
                # Quit automatically (for running sequential expts)
                os._exit(1)
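
As a closing sketch (an assumption-laden example, not part of the file above): the per-iteration pickles written by _log_data can be inspected offline with the same DataLogger, assuming the import path of the standard GPS layout.

from gps.utility.data_logger import DataLogger  # assumed import path

data_files_dir = 'experiments/my_experiment/data_files/'  # hypothetical path
itr = 4
logger = DataLogger()

# File names mirror those written in _log_data above.
algorithm = logger.unpickle(data_files_dir + 'algorithm_itr_%02d.pkl' % itr)
traj_samples = logger.unpickle(data_files_dir + 'traj_sample_itr_%02d.pkl' % itr)
pol_samples = logger.unpickle(data_files_dir + 'pol_sample_itr_%02d.pkl' % itr)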