def setup_agent(T=100):
    """Build an ``AgentROS`` from the module-level defaults.

    Stores ``T`` under ``defaults['agent']['T']`` (the shared ``defaults``
    dict is modified in place), constructs the agent from that config, then
    sleeps for one cycle of a 1 Hz ``rospy.Rate`` before returning.

    Args:
        T: Value written to the agent config's ``'T'`` entry.

    Returns:
        The newly constructed ``AgentROS`` instance.
    """
    agent_config = defaults['agent']
    agent_config['T'] = T
    ros_agent = AgentROS(agent_config)
    # NOTE(review): presumably this pause lets the ROS side settle before the
    # agent is used -- confirm.
    rospy.Rate(1).sleep()
    return ros_agent
def _str2bool(value):
    """Convert a command-line token into a bool for argparse.

    ``type=bool`` maps every non-empty string (including ``'False'``) to
    ``True``; this converter interprets the text instead.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % value)


def main():
    """Parse the command line and create, set up, test or train an experiment.

    Depending on the options this either creates a new experiment directory
    (``--new``), launches the ROS target setup GUI (``--targetsetup``), takes
    N policy samples from the most recent saved iteration (``--policy N``),
    or runs training, optionally resuming from iteration N (``--resume N``).

    The boolean options accept an optional value: ``--new``, ``--new True``
    and ``--new False`` are all valid, and ``False`` really means false.

    Raises:
        SystemExit: after a new experiment was created, when the experiment
            already exists or its hyperparams file is missing, when the ROS
            imports for target setup fail, or when no saved algorithm
            iteration is available for policy testing.
    """
    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')

    parser.add_argument('--experiment',
                        type=str,
                        help='experiment',
                        default='box2d_arm_example')
    parser.add_argument('--new',
                        nargs='?',
                        const=True,
                        default=False,
                        type=_str2bool,
                        help='create new experiment')
    parser.add_argument('--targetsetup',
                        nargs='?',
                        const=True,
                        default=False,
                        type=_str2bool,
                        help='run target setup')
    parser.add_argument('--resume',
                        default=None,
                        type=int,
                        help='resume training from iter N')
    parser.add_argument('--policy',
                        default=None,
                        type=int,
                        help='take N policy samples (for BADMM/MDGPS only)')
    parser.add_argument('--silent',
                        nargs='?',
                        const=True,
                        default=False,
                        type=_str2bool,
                        help='silent debug print outs')
    parser.add_argument('--quit',
                        nargs='?',
                        const=True,
                        default=False,
                        type=_str2bool,
                        help='quit GUI automatically when finished')

    args = parser.parse_args()

    # Main parameters.
    exp_name = args.experiment
    resume_training_itr = args.resume
    test_policy_N = args.policy

    from gps import __file__ as gps_filepath

    # The repository root is three path components above the gps package file.
    gps_filepath = os.path.abspath(gps_filepath)
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    print('gps_dir', gps_dir)
    print('exp_dir', exp_dir)

    hyperparams_file = exp_dir + 'hyperparams.py'

    if args.silent:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.INFO)
    else:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.DEBUG)

    if args.new:
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        # '.previous_experiment' (relative to the working directory) records
        # the directory of the experiment created last.
        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        try:
            with open(prev_exp_file, 'r') as f:
                # Drop a trailing newline so the path concatenation below
                # still works if the file was edited by hand.
                prev_exp_dir = f.readline().rstrip('\r\n')
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # Nothing to copy from: leave a stub that tells the user what to do.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        if prev_exp_dir and os.path.exists(prev_exp_dir):
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)
    print('hyperparams', hyperparams)

    if args.targetsetup:
        # ROS target setup.
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

        except ImportError:
            sys.exit('ROS required for target setup.')

    elif test_policy_N:

        print(
            """" ================================= testing  process =============================="""
        )
        import random
        import re
        import numpy as np

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        data_files_dir = exp_dir + 'data_files/'
        algorithm_prefix = 'algorithm_itr_'
        # Compare iteration numbers numerically: a lexicographic sort plus a
        # two-character slice picks the wrong file once iterations reach 100.
        itr_pattern = re.compile(re.escape(algorithm_prefix) + r'(\d+)')
        itr_matches = [
            itr_pattern.match(name) for name in os.listdir(data_files_dir)
        ]
        saved_itrs = [int(m.group(1)) for m in itr_matches if m]
        if not saved_itrs:
            sys.exit("No saved '%s*' files found in '%s'." %
                     (algorithm_prefix, data_files_dir))
        current_itr = max(saved_itrs)

        gps = GPSMain(hyperparams.config)
        gps.test_policy(itr=current_itr, N=test_policy_N)

    else:
        import random
        import numpy as np

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        print(
            """" ================================ training  process =============================="""
        )
        gps = GPSMain(hyperparams.config, args.quit)

        gps.run(itr_load=resume_training_itr)
# Example #3
# 0
def main():
    """Parse the command line and create, set up, test or train an experiment.

    Modes, in order of precedence:

    * ``-n/--new``: create the experiment directory, seed it from the
      previously created experiment if one is recorded, then exit.
    * ``-t/--targetsetup``: open the ROS target setup GUI.
    * ``-p/--policy N``: take N policy samples from the most recent saved
      algorithm iteration.
    * otherwise: train, optionally resuming from ``-r/--resume N``. Training
      positions are loaded from ``./position/all_train_position.pkl``; when
      the GUI is off, the distances and costs returned by ``gps.run`` are
      pickled under ``./position/``.

    Raises:
        SystemExit: after a new experiment was created, when the experiment
            already exists or its hyperparams file is missing, when the ROS
            imports for target setup fail, or when no saved algorithm
            iteration is available for policy testing.
    """
    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')
    parser.add_argument('experiment', type=str, help='experiment name')
    parser.add_argument('-n',
                        '--new',
                        action='store_true',
                        help='create new experiment')
    parser.add_argument('-t',
                        '--targetsetup',
                        action='store_true',
                        help='run target setup')
    parser.add_argument('-r',
                        '--resume',
                        metavar='N',
                        type=int,
                        help='resume training from iter N')
    parser.add_argument('-p',
                        '--policy',
                        metavar='N',
                        type=int,
                        help='take N policy samples (for BADMM/MDGPS only)')
    parser.add_argument('-s',
                        '--silent',
                        action='store_true',
                        help='silent debug print outs')
    parser.add_argument('-q',
                        '--quit',
                        action='store_true',
                        help='quit GUI automatically when finished')
    parser.add_argument('-c',
                        '--condition',
                        metavar='N',
                        type=int,
                        help='consider N position')
    parser.add_argument('-m',
                        '--num',
                        metavar='N',
                        type=int,
                        help='test\' N nums of experiment')
    parser.add_argument('-exper',
                        '--exper',
                        metavar='N',
                        type=int,
                        help='time of test experiment')
    parser.add_argument('-set',
                        '--set_cond',
                        metavar='N',
                        type=int,
                        help='train on special position setting')
    parser.add_argument('-algi',
                        '--alg_itr',
                        metavar='N',
                        type=int,
                        help='control the time of train NN')

    args = parser.parse_args()

    exp_name = args.experiment
    resume_training_itr = args.resume
    test_policy_N = args.policy

    from gps import __file__ as gps_filepath
    # The repository root is three path components above the gps package file.
    gps_filepath = os.path.abspath(gps_filepath)
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    hyperparams_file = exp_dir + 'hyperparams.py'

    if args.silent:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.INFO)
    else:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.DEBUG)

    if args.new:
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        # '.previous_experiment' (relative to the working directory) records
        # the directory of the experiment created last.
        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        try:
            with open(prev_exp_file, 'r') as f:
                # Drop a trailing newline so the path concatenation below
                # still works if the file was edited by hand.
                prev_exp_dir = f.readline().rstrip('\r\n')
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # Nothing to copy from: leave a stub that tells the user what to do.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        if prev_exp_dir and os.path.exists(prev_exp_dir):
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)
    if args.targetsetup:
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            plt.ioff()
            plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:
        import random
        import re
        import numpy as np
        import matplotlib.pyplot as plt

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        data_files_dir = exp_dir + 'data_files/'
        algorithm_prefix = 'algorithm_itr_'
        # Compare iteration numbers numerically: a lexicographic sort plus a
        # two-character slice picks the wrong file once iterations reach 100.
        itr_pattern = re.compile(re.escape(algorithm_prefix) + r'(\d+)')
        itr_matches = [
            itr_pattern.match(name) for name in os.listdir(data_files_dir)
        ]
        saved_itrs = [int(m.group(1)) for m in itr_matches if m]
        if not saved_itrs:
            sys.exit("No saved '%s*' files found in '%s'." %
                     (algorithm_prefix, data_files_dir))
        current_itr = max(saved_itrs)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # Sampling runs on a daemon thread; plt.show() occupies this one.
            test_policy = threading.Thread(target=lambda: gps.test_policy(
                itr=current_itr, N=test_policy_N))
            test_policy.daemon = True
            test_policy.start()

            plt.ioff()
            plt.show()
        else:
            gps.test_policy(itr=current_itr, N=test_policy_N)
    else:
        if args.condition:
            # Train on the first N positions of the stored training set.
            # NOTE(review): every value written here is overwritten by the
            # "extended setting" section below, so -c currently has no lasting
            # effect -- confirm whether that is intended.
            num_position = args.condition
            data_logger = DataLogger()
            positions = data_logger.unpickle('./position/train_position.pkl')
            hyperparams.agent['conditions'] = num_position
            hyperparams.common['conditions'] = num_position
            hyperparams.algorithm['conditions'] = num_position
            hyperparams.agent['pos_body_offset'] = [
                positions[i] for i in range(num_position)
            ]

        import random
        import numpy as np
        import matplotlib.pyplot as plt

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        # Optional override of the configured number of iterations.
        if args.alg_itr:
            hyperparams.config['iterations'] = args.alg_itr

        # Extended setting: use every stored training position as a condition.
        data_logger = DataLogger()
        train_position = data_logger.unpickle(
            './position/all_train_position.pkl')
        hyperparams.agent['pos_body_offset'] = list(train_position)
        hyperparams.agent['conditions'] = len(train_position)
        hyperparams.common['conditions'] = len(train_position)
        hyperparams.algorithm['conditions'] = len(train_position)

        gps = GPSMain(hyperparams.config, args.quit)
        if hyperparams.config['gui_on']:
            # Training runs on a daemon thread; plt.show() occupies this one.
            run_gps = threading.Thread(
                target=lambda: gps.run(itr_load=resume_training_itr))
            run_gps.daemon = True
            run_gps.start()

            plt.ioff()
            plt.show()
        else:
            costs, mean_cost, position_suc_count, all_distance = gps.run(
                args.num,
                exper_condition=args.set_cond,
                itr_load=resume_training_itr)
            gps.data_logger.pickle('./position/md_all_distance.pkl',
                                   all_distance)
            gps.data_logger.pickle('./position/md_all_cost.pkl', costs)
            """
# Example #4
# 0
def main():
    """Parse the command line and create, set up, test or train an experiment.

    Modes, in order of precedence:

    * ``-n/--new``: create the experiment directory, seed it from the
      previously created experiment if one is recorded, then exit.
    * ``-t/--targetsetup``: open the ROS target setup GUI.
    * ``-p/--policy N``: take N policy samples from the most recent saved
      algorithm iteration.
    * otherwise: train, optionally resuming from ``-r/--resume N``.

    Raises:
        SystemExit: after a new experiment was created, when the experiment
            already exists or the hyperparams file is missing, when the ROS
            imports for target setup fail, or when no saved algorithm
            iteration is available for policy testing.
    """

    # INPUT ARGUMENTS #

    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')
    parser.add_argument('-e',
                        '--experiment',
                        type=str,
                        default='box2d_arm_example',
                        help='experiment name')
    parser.add_argument('-n',
                        '--new',
                        action='store_true',
                        help='create new experiment')
    parser.add_argument('-t',
                        '--targetsetup',
                        action='store_true',
                        help='run target setup')
    parser.add_argument('-r',
                        '--resume',
                        metavar='N',
                        type=int,
                        help='resume training from iter N')
    parser.add_argument('-p',
                        '--policy',
                        metavar='N',
                        type=int,
                        help='take N policy samples (for BADMM/MDGPS only)')
    parser.add_argument('-s',
                        '--silent',
                        action='store_true',
                        help='silent debug print outs')
    parser.add_argument('-q',
                        '--quit',
                        action='store_true',
                        help='quit GUI automatically when finished')
    args = parser.parse_args()

    # INPUT VARIABLES #

    exp_name = args.experiment  # experiment name
    resume_training_itr = args.resume  # iteration to resume training from
    test_policy_N = args.policy  # number of policy samples to take

    # FILE-PATHS #

    from gps import __file__ as gps_filepath
    gps_filepath = os.path.abspath(gps_filepath)
    # The repository root is three path components above the gps package file.
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    # NOTE(review): unlike exp_dir, this path is derived from this script's
    # own path with its last 7 characters removed (presumably a 7-character
    # file name such as 'main.py') -- confirm the intended location.
    hyperparams_file = __file__[:-7] + 'hyperparams.py'

    # LOGGING OPTION #

    if args.silent:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.INFO)
    else:
        logging.basicConfig(format='%(levelname)s:%(message)s',
                            level=logging.DEBUG)

    if args.new:  # create a new experiment, then exit
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        # '.previous_experiment' (relative to the working directory) records
        # the directory of the experiment created last.
        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        try:
            with open(prev_exp_file, 'r') as f:
                # Drop a trailing newline so the path concatenation below
                # still works if the file was edited by hand.
                prev_exp_dir = f.readline().rstrip('\r\n')
            # Carry the hyperparameters (and targets, if any) over from the
            # previous experiment.
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # Nothing could be copied: write a stub that tells the user how
            # to get started. Any other exception still propagates.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        # Record this experiment as the most recent one in either case.
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        if prev_exp_dir and os.path.exists(prev_exp_dir):
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        # The experiment now exists; run again without '-n' to use it.
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)

    if args.targetsetup:  # target setup GUI (needs ROS)
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            plt.ioff()
            plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:  # sample the policy of the latest saved iteration
        import random
        import re
        import numpy as np
        import matplotlib.pyplot as plt

        # Seed both random number generators from the hyperparams.
        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        data_files_dir = exp_dir + 'data_files/'
        algorithm_prefix = 'algorithm_itr_'
        # Compare iteration numbers numerically: a lexicographic sort plus a
        # two-character slice picks the wrong file once iterations reach 100.
        itr_pattern = re.compile(re.escape(algorithm_prefix) + r'(\d+)')
        itr_matches = [
            itr_pattern.match(name) for name in os.listdir(data_files_dir)
        ]
        saved_itrs = [int(m.group(1)) for m in itr_matches if m]
        if not saved_itrs:
            sys.exit("No saved '%s*' files found in '%s'." %
                     (algorithm_prefix, data_files_dir))
        current_itr = max(saved_itrs)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # Sampling runs on a daemon thread (ended automatically when the
            # program terminates); plt.show() occupies this thread.
            test_policy = threading.Thread(target=lambda: gps.test_policy(
                itr=current_itr, N=test_policy_N))
            test_policy.daemon = True
            test_policy.start()

            plt.ioff()  # turn interactive mode off
            plt.show()  # display the plots
        else:
            # No GUI, so no separate thread is needed.
            gps.test_policy(itr=current_itr, N=test_policy_N)
    else:  # training
        import random
        import numpy as np
        import matplotlib.pyplot as plt

        # Seed both random number generators from the hyperparams.
        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        gps = GPSMain(hyperparams.config, args.quit)
        if hyperparams.config['gui_on']:
            # Training runs on a daemon thread (ended automatically when the
            # program terminates); plt.show() occupies this thread.
            run_gps = threading.Thread(
                target=lambda: gps.run(itr_load=resume_training_itr))
            run_gps.daemon = True
            run_gps.start()

            plt.ioff()  # turn interactive mode off
            plt.show()  # display the plots
        else:
            # No GUI, so no separate thread is needed.
            gps.run(itr_load=resume_training_itr)
# Example #5
# 0
def main():
    """Parse the command line and create, set up, test or train an experiment.

    Modes, in order of precedence:

    * ``-n/--new``: create the experiment directory, seed it from the
      previously created experiment if one is recorded, then exit.
    * ``-t/--targetsetup``: open the ROS target setup GUI.
    * ``-p/--policy N``: take N policy samples from the most recent saved
      algorithm iteration.
    * otherwise: train, optionally resuming from ``-r/--resume N``.

    Both random number generators are seeded with 0 before testing/training.

    Raises:
        SystemExit: after a new experiment was created, when the experiment
            already exists or its hyperparams file is missing, when the ROS
            imports for target setup fail, or when no saved algorithm
            iteration is available for policy testing.
    """
    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')
    parser.add_argument('experiment', type=str, help='experiment name')
    parser.add_argument('-n',
                        '--new',
                        action='store_true',
                        help='create new experiment')
    parser.add_argument('-t',
                        '--targetsetup',
                        action='store_true',
                        help='run target setup')
    parser.add_argument('-r',
                        '--resume',
                        metavar='N',
                        type=int,
                        help='resume training from iter N')
    parser.add_argument('-p',
                        '--policy',
                        metavar='N',
                        type=int,
                        help='take N policy samples (for BADMM only)')
    args = parser.parse_args()

    exp_name = args.experiment
    resume_training_itr = args.resume
    test_policy_N = args.policy

    from gps import __file__ as gps_filepath
    # Resolve to an absolute path first: stripping three components from a
    # short relative path (e.g. 'gps/__init__.py') would leave nothing and
    # make gps_dir collapse to '/'.
    gps_filepath = os.path.abspath(gps_filepath)
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    hyperparams_file = exp_dir + 'hyperparams.py'

    if args.new:
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        # '.previous_experiment' (relative to the working directory) records
        # the directory of the experiment created last.
        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        try:
            with open(prev_exp_file, 'r') as f:
                # Drop a trailing newline so the path concatenation below
                # still works if the file was edited by hand.
                prev_exp_dir = f.readline().rstrip('\r\n')
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # Nothing to copy from: leave a stub that tells the user what to do.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        if prev_exp_dir and os.path.exists(prev_exp_dir):
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)
    if args.targetsetup:
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            plt.ioff()
            plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:
        import random
        import re
        import numpy as np
        import matplotlib.pyplot as plt

        random.seed(0)
        np.random.seed(0)

        data_files_dir = exp_dir + 'data_files/'
        algorithm_prefix = 'algorithm_itr_'
        # Compare iteration numbers numerically: a lexicographic sort plus a
        # two-character slice picks the wrong file once iterations reach 100.
        itr_pattern = re.compile(re.escape(algorithm_prefix) + r'(\d+)')
        itr_matches = [
            itr_pattern.match(name) for name in os.listdir(data_files_dir)
        ]
        saved_itrs = [int(m.group(1)) for m in itr_matches if m]
        if not saved_itrs:
            sys.exit("No saved '%s*' files found in '%s'." %
                     (algorithm_prefix, data_files_dir))
        current_itr = max(saved_itrs)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # Sampling runs on a daemon thread; plt.show() occupies this one.
            test_policy = threading.Thread(target=lambda: gps.test_policy(
                itr=current_itr, N=test_policy_N))
            test_policy.daemon = True
            test_policy.start()

            plt.ioff()
            plt.show()
        else:
            gps.test_policy(itr=current_itr, N=test_policy_N)
    else:
        import random
        import numpy as np
        import matplotlib.pyplot as plt

        random.seed(0)
        np.random.seed(0)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # Training runs on a daemon thread; plt.show() occupies this one.
            run_gps = threading.Thread(
                target=lambda: gps.run(itr_load=resume_training_itr))
            run_gps.daemon = True
            run_gps.start()

            plt.ioff()
            plt.show()
        else:
            gps.run(itr_load=resume_training_itr)
# Example #6
# 0
def main():
    """ Main function to be run.

    Parses the command line, derives the experiment directory
    (``experiments/<experiment>/`` under the directory obtained by dropping
    the last three components of the ``gps`` package file path) and then
    performs exactly one of the following:

    * ``--new``: create the experiment directory, seed it with the
      hyperparams (and ``targets.npz`` if present) of the experiment recorded
      in ``.previous_experiment`` or with a placeholder file, then exit.
    * ``--targetsetup``: open the ROS target setup GUI.
    * ``--policy N``: call ``GPSMain.test_policy`` with ``N`` and the highest
      iteration found among the ``algorithm_itr_*`` files in ``data_files/``.
    * experiment ``mjc_peg_ioc_learning_example``: run training once for each
      of 50 randomly drawn ``ioc_cond`` offsets and save two summary plots.
    * otherwise: run training, resuming from ``--resume N`` when given.

    Terminates via ``sys.exit`` when the experiment already exists (with
    ``--new``), when its hyperparams file is missing, when ROS imports fail
    during target setup, or when ``--policy`` finds no saved algorithm file.
    """
    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')
    parser.add_argument('experiment', type=str, help='experiment name')
    parser.add_argument('-n', '--new', action='store_true',
                        help='create new experiment')
    parser.add_argument('-t', '--targetsetup', action='store_true',
                        help='run target setup')
    parser.add_argument('-r', '--resume', metavar='N', type=int,
                        help='resume training from iter N')
    parser.add_argument('-p', '--policy', metavar='N', type=int,
                        help='take N policy samples (for BADMM only)')
    parser.add_argument('-s', '--silent', action='store_true',
                        help='silent debug print outs')
    parser.add_argument('-q', '--quit', action='store_true',
                        help='quit GUI automatically when finished')
    args = parser.parse_args()

    exp_name = args.experiment
    resume_training_itr = args.resume
    test_policy_N = args.policy

    from gps import __file__ as gps_filepath
    gps_filepath = os.path.abspath(gps_filepath)
    # Drop the last three path components of the package file to reach the
    # directory that holds 'experiments/'.
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    hyperparams_file = exp_dir + 'hyperparams.py'

    # --silent drops DEBUG output; INFO and above are always shown.
    logging.basicConfig(
        format='%(levelname)s:%(message)s',
        level=logging.INFO if args.silent else logging.DEBUG)

    if args.new:
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        copied_from_prev = False
        try:
            with open(prev_exp_file, 'r') as f:
                # Strip the line terminator so a hand-edited file still
                # yields a usable directory path.
                prev_exp_dir = f.readline().rstrip('\r\n')
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # No usable previous experiment: start from a placeholder file.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        else:
            copied_from_prev = True
        # Remember this experiment as the template for the next '--new'.
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        # Only claim a copy when the copy actually succeeded.
        if copied_from_prev:
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)

    import matplotlib.pyplot as plt
    import random
    import numpy as np
    # Fixed seeds for both generators used below.
    random.seed(0)
    np.random.seed(0)

    if args.targetsetup:
        try:
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            plt.ioff()
            plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:
        data_files_dir = exp_dir + 'data_files/'
        current_itr = _latest_algorithm_itr(os.listdir(data_files_dir))
        if current_itr is None:
            sys.exit("No 'algorithm_itr_*' file found in '%s'." %
                     data_files_dir)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # The work runs in a daemon thread while the main thread blocks
            # in plt.show().
            test_policy = threading.Thread(target=lambda: gps.test_policy(
                itr=current_itr, N=test_policy_N))
            test_policy.daemon = True
            test_policy.start()

            plt.ioff()
            plt.show()
        else:
            gps.test_policy(itr=current_itr, N=test_policy_N)
    elif exp_name == "mjc_peg_ioc_learning_example":
        def outer_offset():
            # One coordinate from [-0.15, -0.09] or [0.09, 0.15]. Both
            # candidates are drawn before choosing, which keeps the random
            # stream identical to drawing them inline.
            return random.choice([np.random.uniform(-0.15, -0.09),
                                  np.random.uniform(0.09, 0.15)])

        def inner_offset():
            # One coordinate from [-0.08, 0.08].
            return np.random.uniform(-0.08, 0.08)

        # 20 conditions with both coordinates in the outer bands, then 15
        # with only y in the outer bands, then 15 with only x.
        ioc_conditions = [
            np.array([outer_offset(), outer_offset()]) for _ in range(20)
        ]
        ioc_conditions.extend(
            np.array([inner_offset(), outer_offset()]) for _ in range(15))
        ioc_conditions.extend(
            np.array([outer_offset(), inner_offset()]) for _ in range(15))

        exp_iter = hyperparams.config['algorithm']['iterations']
        data_files_dir = exp_dir + 'data_files/'
        mean_dists = []
        pos_body_offset_dists = [np.linalg.norm(cond) for cond in ioc_conditions]

        for i, cond in enumerate(ioc_conditions):
            # Reload the hyperparams module so every run starts from the
            # configuration on disk.
            hyperparams = imp.load_source('hyperparams', hyperparams_file)
            hyperparams.config['algorithm']['ioc_cond'] = cond
            gps = GPSMain(hyperparams.config)
            gps.agent._hyperparams['pos_body_offset'] = [cond]
            gui_on = hyperparams.config['gui_on']
            gps.run(itr_load=resume_training_itr)
            if gui_on:
                plt.close()
            if i == 0:
                # Only the first run's demo/failed conditions are plotted.
                demo_conditions = gps.algorithm.demo_conditions
                failed_conditions = gps.algorithm.failed_conditions
            mean_dists.append(gps.algorithm.dists_to_target[exp_iter - 1][0])
            print("iteration " + repr(i) + ": mean dist is " +
                  repr(mean_dists[i]))

        with open(exp_dir + 'log.txt', 'a') as f:
            f.write('\nThe 50 IOC conditions are: \n' + str(ioc_conditions) +
                    '\n')

        # Figure 1: offset norm versus final recorded distance.
        plt.plot(pos_body_offset_dists, mean_dists, 'ro')
        plt.title("Learning from prior experience using peg insertion")
        plt.xlabel('pos body offset distances to the origin')
        plt.ylabel('mean distances to the target')
        plt.savefig(data_files_dir + 'learning_from_prior.png')
        plt.close()

        from matplotlib.patches import Rectangle

        # Figure 2: where the conditions lie, split by success/failure.
        ioc_conditions_x = [cond[0] for cond in ioc_conditions]
        ioc_conditions_y = [cond[1] for cond in ioc_conditions]
        mean_dists = np.around(mean_dists, decimals=2)
        # Runs whose rounded distance exceeds this count as failed.
        success_threshold = 0.08
        failed_ioc_x = [x for x, d in zip(ioc_conditions_x, mean_dists)
                        if d > success_threshold]
        failed_ioc_y = [y for y, d in zip(ioc_conditions_y, mean_dists)
                        if d > success_threshold]
        success_ioc_x = [x for x, d in zip(ioc_conditions_x, mean_dists)
                         if d <= success_threshold]
        success_ioc_y = [y for y, d in zip(ioc_conditions_y, mean_dists)
                         if d <= success_threshold]
        demo_conditions_x = [cond[0] for cond in demo_conditions]
        demo_conditions_y = [cond[1] for cond in demo_conditions]
        failed_conditions_x = [cond[0] for cond in failed_conditions]
        failed_conditions_y = [cond[1] for cond in failed_conditions]

        subplt = plt.subplot()
        subplt.plot(demo_conditions_x, demo_conditions_y, 'go')
        subplt.plot(failed_conditions_x, failed_conditions_y, 'rx')
        subplt.plot(success_ioc_x, success_ioc_y, 'g^')
        subplt.plot(failed_ioc_x, failed_ioc_y, 'rv')
        # Label every IOC condition with its rounded distance.
        for x, y, txt in zip(ioc_conditions_x, ioc_conditions_y, mean_dists):
            subplt.annotate(txt, (x, y))
        ax = plt.gca()
        # Outline the [-0.08, 0.08] square.
        ax.add_patch(
            Rectangle((-0.08, -0.08), 0.16, 0.16, fill=False,
                      edgecolor='blue'))
        # Shrink the axes vertically to leave room for the legend below.
        box = subplt.get_position()
        subplt.set_position(
            [box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])
        subplt.legend(
            ['demo_cond', 'failed_badmm', 'success_ioc', 'failed_ioc'],
            loc='upper center', bbox_to_anchor=(0.5, -0.05),
            shadow=True, ncol=2)
        plt.title(
            "Distribution of neural network and IOC's initial conditions")
        plt.savefig(data_files_dir + 'distribution_of_conditions.png')
        plt.show()
    else:
        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # The work runs in a daemon thread while the main thread blocks
            # in plt.show().
            run_gps = threading.Thread(
                target=lambda: gps.run(itr_load=resume_training_itr))
            run_gps.daemon = True
            run_gps.start()

            plt.ioff()
            plt.show()
        else:
            gps.run(itr_load=resume_training_itr)


def _latest_algorithm_itr(filenames, prefix='algorithm_itr_'):
    """Return the highest iteration number encoded in saved algorithm files.

    A name counts when it starts with ``prefix`` immediately followed by one
    or more digits; those digits are the iteration number. Comparing the
    parsed integers (rather than the file names as strings) keeps the result
    correct for any number of digits.

    Args:
        filenames: Iterable of file names (not paths).
        prefix: Leading part of the file names to consider.

    Returns:
        The largest iteration number found, or None when no name matches.
    """
    import re

    pattern = re.compile(re.escape(prefix) + r'(\d+)')
    itrs = [int(m.group(1)) for m in map(pattern.match, filenames) if m]
    return max(itrs) if itrs else None
Пример #7
0
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)

    if args.targetsetup:
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            # plt.ioff()
            # plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:
        import random
        import numpy as np
        import matplotlib.pyplot as plt

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)
Пример #8
0
def main():
    """ Main function to be run.

    Parses the command line, derives the experiment directory
    (``experiments/<experiment>/`` under the directory obtained by dropping
    the last three components of the ``gps`` package file path) and then
    performs exactly one of the following:

    * ``--new``: create the experiment directory, seed it with the
      hyperparams (and ``targets.npz`` if present) of the experiment recorded
      in ``.previous_experiment`` or with a placeholder file, then exit.
    * ``--targetsetup``: open the ROS target setup GUI.
    * ``--policy N``: call ``GPSMain.test_policy`` with ``N`` and the highest
      iteration found among the ``algorithm_itr_*`` files in ``data_files/``.
    * otherwise: run training, resuming from ``--resume N`` when given.

    The random generators are seeded with ``random_seed`` from the
    experiment's config (0 when absent) before sampling or training.

    Terminates via ``sys.exit`` when the experiment already exists (with
    ``--new``), when its hyperparams file is missing, when ROS imports fail
    during target setup, or when ``--policy`` finds no saved algorithm file.
    """
    parser = argparse.ArgumentParser(
        description='Run the Guided Policy Search algorithm.')
    parser.add_argument('experiment', type=str, help='experiment name')
    parser.add_argument('-n', '--new', action='store_true',
                        help='create new experiment')
    parser.add_argument('-t', '--targetsetup', action='store_true',
                        help='run target setup')
    parser.add_argument('-r', '--resume', metavar='N', type=int,
                        help='resume training from iter N')
    parser.add_argument('-p', '--policy', metavar='N', type=int,
                        help='take N policy samples (for BADMM/MDGPS only)')
    parser.add_argument('-s', '--silent', action='store_true',
                        help='silent debug print outs')
    parser.add_argument('-q', '--quit', action='store_true',
                        help='quit GUI automatically when finished')
    # 'args' holds the values given on the command line.
    args = parser.parse_args()

    exp_name = args.experiment
    resume_training_itr = args.resume
    test_policy_N = args.policy

    from gps import __file__ as gps_filepath
    # Absolute path of the file the 'gps' package was imported from.
    gps_filepath = os.path.abspath(gps_filepath)
    # Drop the last three path components of that file to reach the
    # directory that holds 'experiments/'.
    gps_dir = '/'.join(str.split(gps_filepath, '/')[:-3]) + '/'
    # Each experiment lives in its own folder and is configured by the
    # 'hyperparams.py' file inside it.
    exp_dir = gps_dir + 'experiments/' + exp_name + '/'
    hyperparams_file = exp_dir + 'hyperparams.py'

    # --silent drops DEBUG output; INFO and above are always shown.
    logging.basicConfig(
        format='%(levelname)s:%(message)s',
        level=logging.INFO if args.silent else logging.DEBUG)

    if args.new:
        from shutil import copy

        if os.path.exists(exp_dir):
            sys.exit("Experiment '%s' already exists.\nPlease remove '%s'." %
                     (exp_name, exp_dir))
        os.makedirs(exp_dir)

        prev_exp_file = '.previous_experiment'
        prev_exp_dir = None
        copied_from_prev = False
        try:
            with open(prev_exp_file, 'r') as f:
                # Strip the line terminator so a hand-edited file still
                # yields a usable directory path.
                prev_exp_dir = f.readline().rstrip('\r\n')
            copy(prev_exp_dir + 'hyperparams.py', exp_dir)
            if os.path.exists(prev_exp_dir + 'targets.npz'):
                copy(prev_exp_dir + 'targets.npz', exp_dir)
        except IOError:
            # No usable previous experiment: start from a placeholder file.
            with open(hyperparams_file, 'w') as f:
                f.write(
                    '# To get started, copy over hyperparams from another experiment.\n'
                    +
                    '# Visit rll.berkeley.edu/gps/hyperparams.html for documentation.'
                )
        else:
            copied_from_prev = True
        # Remember this experiment as the template for the next '--new'.
        with open(prev_exp_file, 'w') as f:
            f.write(exp_dir)

        exit_msg = ("Experiment '%s' created.\nhyperparams file: '%s'" %
                    (exp_name, hyperparams_file))
        # Only claim a copy when the copy actually succeeded.
        if copied_from_prev:
            exit_msg += "\ncopied from     : '%shyperparams.py'" % prev_exp_dir
        sys.exit(exit_msg)

    if not os.path.exists(hyperparams_file):
        sys.exit("Experiment '%s' does not exist.\nDid you create '%s'?" %
                 (exp_name, hyperparams_file))

    hyperparams = imp.load_source('hyperparams', hyperparams_file)
    if args.targetsetup:
        try:
            import matplotlib.pyplot as plt
            from gps.agent.ros.agent_ros import AgentROS
            from gps.gui.target_setup_gui import TargetSetupGUI

            agent = AgentROS(hyperparams.config['agent'])
            TargetSetupGUI(hyperparams.config['common'], agent)

            plt.ioff()
            plt.show()
        except ImportError:
            sys.exit('ROS required for target setup.')
    elif test_policy_N:
        import random
        import numpy as np
        import matplotlib.pyplot as plt

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        data_files_dir = exp_dir + 'data_files/'
        current_itr = _latest_algorithm_itr(os.listdir(data_files_dir))
        if current_itr is None:
            sys.exit("No 'algorithm_itr_*' file found in '%s'." %
                     data_files_dir)

        gps = GPSMain(hyperparams.config)
        if hyperparams.config['gui_on']:
            # The work runs in a daemon thread while the main thread blocks
            # in plt.show().
            test_policy = threading.Thread(target=lambda: gps.test_policy(
                itr=current_itr, N=test_policy_N))
            test_policy.daemon = True
            test_policy.start()

            plt.ioff()
            plt.show()
        else:
            gps.test_policy(itr=current_itr, N=test_policy_N)
    else:
        import random
        import numpy as np
        import matplotlib.pyplot as plt

        seed = hyperparams.config.get('random_seed', 0)
        random.seed(seed)
        np.random.seed(seed)

        # 'hyperparams.config' is the dictionary defined by the experiment's
        # hyperparams file; its entries are described at
        # https://github.com/cbfinn/gps/blob/master/docs/hyperparams.md
        gps = GPSMain(hyperparams.config, args.quit)
        if hyperparams.config['gui_on']:
            # The work runs in a daemon thread while the main thread blocks
            # in plt.show().
            run_gps = threading.Thread(
                target=lambda: gps.run(itr_load=resume_training_itr))
            run_gps.daemon = True
            run_gps.start()

            plt.ioff()
            plt.show()
        else:
            gps.run(itr_load=resume_training_itr)


def _latest_algorithm_itr(filenames, prefix='algorithm_itr_'):
    """Return the highest iteration number encoded in saved algorithm files.

    A name counts when it starts with ``prefix`` immediately followed by one
    or more digits; those digits are the iteration number. Comparing the
    parsed integers (rather than the file names as strings) keeps the result
    correct for any number of digits.

    Args:
        filenames: Iterable of file names (not paths).
        prefix: Leading part of the file names to consider.

    Returns:
        The largest iteration number found, or None when no name matches.
    """
    import re

    pattern = re.compile(re.escape(prefix) + r'(\d+)')
    itrs = [int(m.group(1)) for m in map(pattern.match, filenames) if m]
    return max(itrs) if itrs else None