Exemplo n.º 1
0
def main():
    """Run a hyperparameter search for an RL algorithm and save every evaluation to CSV.

    Options this parser does not recognise are kept in ``train_args`` and, together with
    the search settings, forwarded to the training function built by
    ``makeRlTrainingFunction``. After the search the best configuration is printed and
    every evaluation is written to
    ``logs/<optimizer>_<algo>_<env>_<srl_model>_seed<seed>_numtimestep<num_timesteps>.csv``.

    Raises:
        AssertionError: if the algorithm has no ``getOptParam`` or it returns ``None``.
        ValueError: if ``--optimizer`` names an optimizer with no implementation below.
        RuntimeError: if the optimizer finished without recording any evaluation.
    """
    parser = argparse.ArgumentParser(description="Hyperparameter search for implemented RL models")
    parser.add_argument('--optimizer', default='hyperband', choices=['hyperband', 'hyperopt'], type=str,
                        help='The hyperparameter optimizer to choose from')
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='OpenAI baseline to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    # int(1e6), not 1e6: argparse does not apply `type` to non-string defaults, so a float default
    # would leak into max_iter and into the CSV file name ("...numtimestep1000000.0.csv").
    parser.add_argument('--num-timesteps', type=int, default=int(1e6),
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Display baseline STDOUT')
    parser.add_argument('--max-eval', type=int, default=100, help='Number of evalutation to try for hyperopt')

    args, train_args = parser.parse_known_args()
    args.log_dir = "logs/_{}_search/".format(args.optimizer)

    # Settings forwarded to every training run launched by the optimizer.
    train_args.extend(['--srl-model', args.srl_model, '--seed', str(args.seed), '--algo', args.algo, '--env', args.env,
                       '--log-dir', args.log_dir, '--no-vis'])

    # Verify the algorithm defines getOptParam() and that it returns a search space.
    # (The previous `except AttributeError or AssertionError:` only caught AttributeError, because
    # `A or B` evaluates to A; a None search space escaped as a bare AssertionError.)
    unsupported_msg = "Error: {} algo does not support hyperparameter search.".format(args.algo)
    try:
        opt_param = registered_rl[args.algo][0].getOptParam()
    except AttributeError as err:
        raise AssertionError(unsupported_msg) from err
    if opt_param is None:
        raise AssertionError(unsupported_msg)

    if args.optimizer == "hyperband":
        opt = Hyperband(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                        max_iter=args.num_timesteps // ITERATION_SCALE)
    elif args.optimizer == "hyperopt":
        opt = Hyperopt(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed, num_eval=args.max_eval)
    else:
        raise ValueError("Error: optimizer {} was defined but not implemented, Halting.".format(args.optimizer))

    t_start = time.time()
    opt.run()
    if not opt.history:
        raise RuntimeError("Error: the {} optimizer did not record any evaluation.".format(args.optimizer))

    # Each history entry is ((params, nb_iter), loss). The loss is reported below as a reward by
    # negating it (presumably the training function returns -reward -- confirm).
    all_params, loss = zip(*opt.history)
    idx = int(np.argmin(loss))
    opt_params, nb_iter = all_params[idx]
    reward = loss[idx]
    print('\ntime to run : {}s'.format(int(time.time() - t_start)))
    print('Total nb. evaluations : {}'.format(len(all_params)))
    if nb_iter is not None:
        print('Best nb. of iterations : {}'.format(int(nb_iter)))
    print('Best params : ')
    pprint.pprint(opt_params)
    print('Best reward : {:.3f}'.format(-reward))

    param_dict, timesteps = zip(*all_params)
    output = pd.DataFrame(list(param_dict))
    # Only log timesteps when every evaluation reported an iteration count.
    if all(el is not None for el in timesteps):
        output["timesteps"] = np.maximum(MIN_ITERATION, np.array(timesteps) * ITERATION_SCALE).astype(int)
    output["reward"] = -np.array(loss)
    output.to_csv("logs/{}_{}_{}_{}_seed{}_numtimestep{}.csv"
                  .format(args.optimizer, args.algo, args.env, args.srl_model, args.seed, args.num_timesteps))
Exemplo n.º 2
0
def main():
    """Parse command-line options, configure the environment and train an RL agent.

    Known options are parsed first to validate the algorithm / environment / SRL-model
    combination. The algorithm's ``customArguments`` are then added and the whole command
    line is re-parsed with ``parse_args`` (unknown options become an error). Module-level
    variables intended for the training callback are set here. The arguments are saved to
    ``LOG_DIR + "args.json"`` and the environment parameters via ``saveEnvParams`` before
    ``algo.train`` is called.

    Raises:
        AssertionError: on invalid or incompatible command-line options.
        ValueError: if the algorithm does not support the requested action type.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    # These were previously assigned as plain locals, so the parsed values never reached module level.
    global CROSS_EVAL, EPISODE_WINDOW_DISTILLATION_WIN, NEW_LR
    parser = argparse.ArgumentParser(
        description="Train script for RL algorithms")
    parser.add_argument('--algo',
                        default='ppo2',
                        choices=list(registered_rl.keys()),
                        help='RL algo to use',
                        type=str)
    parser.add_argument('--env',
                        type=str,
                        help='environment ID',
                        default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    parser.add_argument(
        '--episode-window',
        type=int,
        default=40,
        help='Episode window for moving average plot (default: 40)')
    parser.add_argument(
        '--log-dir',
        default='/tmp/gym/',
        type=str,
        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model',
                        type=str,
                        default='raw_pixels',
                        choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack',
                        type=int,
                        default=1,
                        help='number of frames to stack (default: 1)')
    parser.add_argument(
        '--action-repeat',
        type=int,
        default=1,
        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port',
                        type=int,
                        default=8097,
                        help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis',
                        action='store_true',
                        default=False,
                        help='disables visdom visualization')
    parser.add_argument(
        '--shape-reward',
        action='store_true',
        default=False,
        help='Shape the reward (reward = - distance) instead of a sparse reward'
    )
    parser.add_argument('-c',
                        '--continuous-actions',
                        action='store_true',
                        default=False)
    parser.add_argument(
        '-joints',
        '--action-joints',
        action='store_true',
        default=False,
        help=
        'set actions to the joints of the arm directly, instead of inverse kinematics'
    )
    parser.add_argument('-r',
                        '--random-target',
                        action='store_true',
                        default=False,
                        help='Set the button to a random position')
    parser.add_argument(
        '--srl-config-file',
        type=str,
        default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save',
                        type=int,
                        default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument(
        '--latest',
        action='store_true',
        default=False,
        help=
        'load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument(
        '--load-rl-model-path',
        type=str,
        default=None,
        help="load the trained RL model, should be with the same algorithm type"
    )
    parser.add_argument(
        '-sc',
        '--simple-continual',
        action='store_true',
        default=False,
        help=
        'Simple red square target for task 1 of continual learning scenario. '
        + 'The task is: robot should reach the target.')
    parser.add_argument(
        '-cc',
        '--circular-continual',
        action='store_true',
        default=False,
        help='Blue square target for task 2 of continual learning scenario. ' +
        'The task is: robot should turn in circle around the target.')
    parser.add_argument(
        '-sqc',
        '--square-continual',
        action='store_true',
        default=False,
        help='Green square target for task 3 of continual learning scenario. '
        + 'The task is: robot should turn in square around the target.')
    parser.add_argument(
        '-ec',
        '--eight-continual',
        action='store_true',
        default=False,
        help='Green square target for task 4 of continual learning scenario. '
        +
        'The task is: robot should do the eigth with the target as center of the shape.'
    )
    parser.add_argument('--teacher-data-folder',
                        type=str,
                        default="",
                        help='Dataset folder of the teacher(s) policy(ies)',
                        required=False)
    parser.add_argument(
        '--epochs-distillation',
        type=int,
        default=30,
        metavar='N',
        help='number of epochs to train for distillation(default: 30)')
    parser.add_argument(
        '--distillation-training-set-size',
        type=int,
        default=-1,
        help='Limit size (number of samples) of the training set (default: -1)'
    )
    parser.add_argument(
        '--perform-cross-evaluation-cc',
        action='store_true',
        default=False,
        help='A cross evaluation from the latest stored model to all tasks')
    parser.add_argument(
        '--eval-episode-window',
        type=int,
        default=400,
        metavar='N',
        help=
        'Episode window for saving each policy checkpoint for future distillation(default: 100)'
    )
    parser.add_argument(
        '--new-lr',
        type=float,
        default=1.e-4,
        help="New learning rate ratio to train a pretrained agent")
    parser.add_argument('--img-shape',
                        type=str,
                        default="(3,64,64)",
                        help="Image shape of environment.")
    parser.add_argument(
        "--gpu-num",
        help="Choose the number of GPU (CUDA_VISIBLE_DEVICES).",
        type=str,
        default="1",
        choices=["0", "1", "2", "3", "5", "6", "7", "8"])
    parser.add_argument("--srl-model-path",
                        help="SRL model weights path",
                        type=str,
                        default=None)
    parser.add_argument(
        "--relative-pos",
        action='store_true',
        default=False,
        help="For 'ground_truth': use relative position or not.")
    # Ignore unknown args for now: the algorithm-specific options are only added further down.
    args, unknown = parser.parse_known_args()
    # NOTE: --gpu-num is parsed but not applied (setting CUDA_VISIBLE_DEVICES is disabled).
    env_kwargs = {}
    if args.img_shape is None:
        img_shape = None  #(3,224,224)
    else:
        # "(3,64,64)" -> (3, 64, 64)
        img_shape = tuple(map(int, args.img_shape[1:-1].split(",")))
    env_kwargs['img_shape'] = img_shape

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # safe_load: plain yaml.load without a Loader is unsafe and raises TypeError on PyYAML >= 6.
        all_models = yaml.safe_load(f)

    # Sanity check
    assert args.episode_window >= 1, "Error: --episode-window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    # 65535 is a valid port; the previous strict `<` rejected it despite the message.
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    compatible_classes = registered_srl[args.srl_model][1]
    if compatible_classes is not None:
        assert any(issubclass(compatible_class, registered_env[args.env][0])
                   for compatible_class in compatible_classes), \
            "Error: srl_model {}, is not compatible with the {} environment.".format(args.srl_model, args.env)

    # NOTE(review): this only rejects several continual scenarios *on OmnirobotEnv-v0*; it does not
    # reject a scenario on another environment, although the message asks for both -- confirm intent.
    n_continual_scenarios = sum([args.simple_continual, args.circular_continual, args.square_continual,
                                 args.eight_continual])
    assert not (n_continual_scenarios > 1 and args.env == "OmnirobotEnv-v0"), \
        "For continual SRL and RL, please provide only one scenario at the time and use OmnirobotEnv-v0 environment !"

    assert not (args.algo == "distillation" and (args.teacher_data_folder == '' or args.continuous_actions is True)), \
        "For performing policy distillation, make sure use specify a valid teacher dataset and discrete actions !"

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save
    CROSS_EVAL = args.perform_cross_evaluation_cc
    EPISODE_WINDOW_DISTILLATION_WIN = args.eval_episode_window
    NEW_LR = args.new_lr
    print("EPISODE_WINDOW_DISTILLATION_WIN: ", EPISODE_WINDOW_DISTILLATION_WIN)

    if args.no_vis:
        viz = False

    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(
            args.algo +
            " does not support discrete actions, please use the '--continuous-actions' "
            + "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(
            args.algo +
            " does not support continuous actions, please remove the '--continuous-actions' "
            + "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button; the simple continual scenario always uses a random target.
    env_kwargs["random_target"] = args.random_target or args.simple_continual is True

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"

    # Second, strict parse now that the algorithm-specific options are registered.
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    env_class = registered_env[args.env][0]
    # env default kwargs.
    # NOTE(review): `v` is an inspect.Parameter and is never None, so every parameter is kept,
    # including ones without a default (value inspect.Parameter.empty) -- confirm that is intended.
    default_env_kwargs = {
        k: v.default
        for k, v in inspect.signature(env_class.__init__).parameters.items()
        if v is not None
    }

    globals_env_param = sys.modules[env_class.__module__].getGlobals()
    # HACK: override the render size with --img-shape (channels, height, width).
    if img_shape is not None:
        globals_env_param['RENDER_HEIGHT'] = img_shape[1]
        globals_env_param['RENDER_WIDTH'] = img_shape[2]
    globals_env_param['RELATIVE_POS'] = args.relative_pos

    super_class = registered_env[args.env][1]
    # recursive search through all the super classes of the asked environment, in order to get all the arguments.
    rec_super_class_lookup = {
        dict_class: dict_super_class
        for _, (dict_class, dict_super_class, _, _) in registered_env.items()
    }
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {
            k: v.default
            for k, v in inspect.signature(
                super_class.__init__).parameters.items() if v is not None
        }
        # Subclass values take precedence over the super class ones.
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}

        globals_env_param = {
            **sys.modules[super_class.__module__].getGlobals(),
            **globals_env_param
        }

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(
        filterJSONSerializableObjects({
            **globals_env_param,
            **default_env_kwargs,
            **env_kwargs
        }))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using multiprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband): entries look like "name:value".
    hyperparams = {
        param.split(":")[0]: param.split(":")[1]
        for param in args.hyperparam
    }
    hyperparams = algo.parserHyperParam(hyperparams)

    if args.load_rl_model_path is not None:
        # Fine-tuning a pretrained agent: scale the learning-rate schedule by --new-lr
        # (previously hard-coded to 1.0e-4, which is also the option's default).
        lr_scale = NEW_LR
        print("use a small learning rate: {:f}".format(lr_scale))
        hyperparams["learning_rate"] = lambda f: f * lr_scale
        algo.setLoadPath(args.load_rl_model_path)

    # Train the agent
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
Exemplo n.º 3
0
def main():
    """Benchmark an RL algorithm over every (SRL model, environment) pair and several seeds.

    Arguments not recognised by this parser are forwarded verbatim to
    ``python -m rl_baselines.train``. That command is launched once per
    (srl_model, env, seed) combination.

    Raises:
        AssertionError: on invalid arguments, a malformed SRL config file or an
            incompatible SRL model / environment combination.
        ChildProcessError: if a training subprocess exits with a non-zero code.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI RL Baselines Benchmark",
        epilog=
        'After the arguments are parsed, the rest are assumed to be arguments for'
        + ' rl_baselines.train')
    parser.add_argument('--algo',
                        type=str,
                        default='ppo2',
                        help='OpenAI baseline to use',
                        choices=list(registered_rl.keys()))
    parser.add_argument('--env',
                        type=str,
                        nargs='+',
                        default=["KukaButtonGymEnv-v0"],
                        help='environment ID(s)',
                        choices=list(registered_env.keys()))
    parser.add_argument('--srl-model',
                        type=str,
                        nargs='+',
                        default=["raw_pixels"],
                        help='SRL model(s) to use',
                        choices=list(registered_srl.keys()))
    # int(1e6), not 1e6: argparse does not convert non-string defaults with `type`.
    parser.add_argument('--num-timesteps',
                        type=int,
                        default=int(1e6),
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        default=False,
                        help='Display baseline STDOUT')
    parser.add_argument(
        '--num-iteration',
        type=int,
        default=15,
        help=
        'number of time each algorithm should be run for each unique combination of environment '
        + ' and srl-model.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help=
        'initial seed for each unique combination of environment and srl-model.'
    )
    parser.add_argument(
        '--srl-config-file',
        type=str,
        default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')

    # returns the parsed arguments, and the rest are assumed to be arguments for rl_baselines.train
    args, train_args = parser.parse_known_args()

    # Sanity check
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_iteration >= 1, "Error: --num-iteration cannot be less than 1"

    # Removing duplicates and sort
    srl_models = sorted(set(args.srl_model))
    envs = sorted(set(args.env))

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # safe_load: plain yaml.load without a Loader is unsafe and raises TypeError on PyYAML >= 6.
        all_models = yaml.safe_load(f)

    # Checking definition and presence of all requested srl_models; report every problem before failing.
    valid = True
    for env in envs:
        # validate the env definition
        if env not in all_models:
            printRed(
                "Error: 'srl_models.yaml' missing definition for environment {}"
                .format(env))
            valid = False
            continue  # skip to the next env, this one is not valid

        # checking log_folder for current env
        missing_log = "log_folder" not in all_models[env]
        if missing_log:
            printRed(
                "Error: 'srl_models.yaml' missing definition for log_folder in environment {}"
                .format(env))
            valid = False

        # validate each model for the current env definition
        for model in srl_models:
            if registered_srl[model][0] == SRLType.ENVIRONMENT:
                continue  # not an srl model, skip to the next model
            elif model not in all_models[env]:
                printRed(
                    "Error: 'srl_models.yaml' missing srl_model {} for environment {}"
                    .format(model, env))
                valid = False
            elif (not missing_log) and (
                    not os.path.exists(all_models[env]["log_folder"] +
                                       all_models[env][model])):
                # checking presence of srl_model path, if and only if log_folder exists
                printRed(
                    "Error: srl_model {} for environment {} was defined in ".
                    format(model, env) +
                    "'srl_models.yaml', however the file {} it was tagetting does not exist."
                    .format(all_models[env]["log_folder"] +
                            all_models[env][model]))
                valid = False

    assert valid, "Errors occured due to malformed 'srl_models.yaml', cannot continue."

    # check that all the SRL_models can be run on all the environments
    valid = True
    for env in envs:
        for model in srl_models:
            compatible_classes = registered_srl[model][1]
            if compatible_classes is None:
                continue  # no restriction for this model
            if not any(issubclass(compatible_class, registered_env[env][0])
                       for compatible_class in compatible_classes):
                valid = False
                printRed(
                    "Error: srl_model {}, is not compatible with the {} environment."
                    .format(model, env))
    assert valid, "Errors occured due to an incompatible combination of srl_model and environment, cannot continue."

    # the seeds used in training the baseline.
    seeds = list(range(args.seed, args.seed + args.num_iteration))

    # None means "inherit the terminal's stdout". DEVNULL replaces an os.devnull handle
    # that was opened here and never closed.
    stdout = None if args.verbose else subprocess.DEVNULL

    printGreen("\nRunning {} benchmarks {} times...".format(
        args.algo, args.num_iteration))
    print("\nSRL-Models:\t{}".format(srl_models))
    print("environments:\t{}".format(envs))
    print("verbose:\t{}".format(args.verbose))
    print("timesteps:\t{}".format(args.num_timesteps))
    for model in srl_models:
        for env in envs:
            for i, seed in enumerate(seeds):

                printGreen(
                    "\nIteration_num={} (seed: {}), Environment='{}', SRL-Model='{}'"
                    .format(i, seed, env, model))

                # redefine the parsed args for rl_baselines.train
                loop_args = [
                    '--srl-model', model, '--seed',
                    str(seed), '--algo', args.algo, '--env', env,
                    '--num-timesteps',
                    str(int(args.num_timesteps)), '--srl-config-file',
                    args.srl_config_file
                ]

                ok = subprocess.call(['python', '-m', 'rl_baselines.train'] +
                                     train_args + loop_args,
                                     stdout=stdout)

                if ok != 0:
                    # throw the error down to the terminal
                    raise ChildProcessError(
                        "An error occured, error code: {}".format(ok))
Exemplo n.º 4
0
def parseMeanReward(output, prefix="Mean reward: ", suffix="\npybullet"):
    """Extract the mean reward printed by ``replay.enjoy_baselines``.

    The value is the text between ``prefix`` and the first ``suffix`` that follows it.
    When no ``suffix`` follows, the value ends at the next newline (or at the end of
    ``output``).

    :param output: (str) decoded stdout of the replay script
    :param prefix: (str) marker printed right before the reward value
    :param suffix: (str) marker expected right after the reward line
    :return: (float) the parsed mean reward
    :raises ValueError: if ``prefix`` is absent or the extracted text is not a number
    """
    start = output.find(prefix)
    if start == -1:
        raise ValueError("Could not find {!r} in the replay output".format(prefix))
    start += len(prefix)
    # Search *after* the prefix: an earlier "\npybullet" line would otherwise give an empty slice.
    end = output.find(suffix, start)
    if end == -1:
        end = output.find("\n", start)
    if end == -1:
        end = len(output)
    return float(output[start:end])


def main():
    """Evaluate policy distillation at several checkpoints of a learning teacher.

    For each selected checkpoint episode of the teacher in ``--log-dir-teacher-two``, the
    script does the following:

    1. Generate on-policy data from that teacher.
    2. Merge it with data from the teacher in ``--log-dir-teacher-one``. This step is
       skipped when that option is the literal string "None".
    3. Train a student with ``trainStudent``.
    4. Replay the student on the "-sc" and "-cc" tasks for ``--num-iteration`` seeds.

    The collected mean rewards are saved as JSON in ``--log-dir-student``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Evaluation script for distillation from two teacher policies")
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--env',
                        type=str,
                        help='environment ID',
                        default='OmnirobotEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument(
        '--episode_window',
        type=int,
        default=40,
        help='Episode window for moving average plot (default: 40)')
    parser.add_argument(
        '--log-dir-teacher-one',
        default='/tmp/gym/',
        type=str,
        help=
        'directory to load an optmimal agent for task 1 (default: /tmp/gym)')
    parser.add_argument(
        '--log-dir-teacher-two',
        default='/tmp/gym/',
        type=str,
        help=
        'directory to load an optmimal agent for task 2 (default: /tmp/gym)')
    parser.add_argument(
        '--log-dir-student',
        default='/tmp/gym/',
        type=str,
        help=
        'directory to save the student agent logs and model (default: /tmp/gym)'
    )
    parser.add_argument(
        '--srl-config-file-one',
        type=str,
        default="config/srl_models_one.yaml",
        help='Set the location of the SRL model path configuration.')
    parser.add_argument(
        '--srl-config-file-two',
        type=str,
        default="config/srl_models_two.yaml",
        help='Set the location of the SRL model path configuration.')
    parser.add_argument(
        '--epochs-distillation',
        type=int,
        default=30,
        metavar='N',
        help='number of epochs to train for distillation(default: 30)')
    parser.add_argument(
        '--distillation-training-set-size',
        type=int,
        default=-1,
        help='Limit size (number of samples) of the training set (default: -1)'
    )
    parser.add_argument(
        '--eval-tasks',
        type=str,
        nargs='+',
        default=['cc', 'sqc', 'sc'],
        help='A cross evaluation from the latest stored model to all tasks')
    parser.add_argument(
        '--continual-learning-labels',
        type=str,
        nargs=2,
        metavar=('label_1', 'label_2'),
        default=argparse.SUPPRESS,
        help='Labels for the continual learning RL distillation task.')
    parser.add_argument('--student-srl-model',
                        type=str,
                        default='raw_pixels',
                        choices=list(registered_srl.keys()),
                        help='SRL model to use for the student RL policy')
    parser.add_argument(
        '--epochs-teacher-datasets',
        type=int,
        default=30,
        metavar='N',
        help=
        'number of epochs for generating both RL teacher datasets (default: 30)'
    )
    parser.add_argument(
        '--num-iteration',
        type=int,
        default=1,
        help='number of time each algorithm should be run the eval (N seeds).')
    parser.add_argument(
        '--eval-episode-window',
        type=int,
        default=400,
        metavar='N',
        help=
        'Episode window for saving each policy checkpoint for future distillation(default: 100)'
    )

    args, unknown = parser.parse_known_args()

    # The labels are used unconditionally below (dataset folders and task ids). With
    # default=SUPPRESS a missing option previously surfaced later as an AttributeError.
    if 'continual_learning_labels' not in args:
        parser.error("the following arguments are required: --continual-learning-labels")
    assert args.continual_learning_labels[0] in CONTINUAL_LEARNING_LABELS and args.continual_learning_labels[1] \
           in CONTINUAL_LEARNING_LABELS, "Please specify a valid Continual learning label to each dataset to be " \
                                         "used for RL distillation !"
    print(args.continual_learning_labels)
    assert os.path.exists(args.srl_config_file_one), \
        "Error: cannot load \"--srl-config-file-one {}\", file not found!".format(args.srl_config_file_one)
    assert os.path.exists(args.srl_config_file_two), \
        "Error: cannot load \"--srl-config-file-two {}\", file not found!".format(args.srl_config_file_two)

    # The literal string "None" disables the first (optimal) teacher.
    use_teacher_one = args.log_dir_teacher_one != "None"
    if use_teacher_one:
        assert os.path.exists(args.log_dir_teacher_one), \
            "Error: cannot load \"--log-dir-teacher-one {}\", directory not found!".format(args.log_dir_teacher_one)
    assert os.path.exists(args.log_dir_teacher_two), \
        "Error: cannot load \"--log-dir-teacher-two {}\", directory not found!".format(args.log_dir_teacher_two)

    label_one, label_two = args.continual_learning_labels
    teacher_pro = args.log_dir_teacher_one
    teacher_learn = args.log_dir_teacher_two

    # Dataset folders (relative to data/) receiving each teacher's on-policy data.
    teacher_pro_data = label_one + '/'
    teacher_learn_data = label_two + '/'
    merge_path = "data/on_policy_merged"

    print(teacher_pro_data, teacher_learn_data)
    episodes, _ = allPolicy(teacher_learn)

    rewards_at_episode = {}
    # Checkpoints every 200 episodes up to 2000 inclusive, then every 1000 episodes.
    # (Episode 2000 was previously skipped by using `< 2000` together with `> 2000`.)
    episodes_to_test = [
        e for e in episodes
        if (int(e) <= 2000 and int(e) % 200 == 0) or (int(e) > 2000 and int(e) % 1000 == 0)
    ]

    # generate data from Professional teacher
    printYellow("\nGenerating on policy for optimal teacher: " + label_one)

    if use_teacher_one:
        # Generated once into data/<label_one>_copy/ and copied to data/<label_one>/ before each merge.
        OnPolicyDatasetGenerator(teacher_pro,
                                 label_one + '_copy/',
                                 task_id=label_one,
                                 num_eps=args.epochs_teacher_datasets,
                                 episode=-1,
                                 env_name=args.env)
    print("Eval on eps list: ", episodes_to_test)
    for eps in episodes_to_test:
        printBlue("\n\nEvaluation at episode " + str(eps))

        if use_teacher_one:
            # Use a copy of the optimal teacher.
            # NOTE(review): if data/<label_one>/ already exists, `cp -r` copies the source *into* it as a
            # subdirectory rather than replacing it -- confirm mergeData consumes/removes that folder.
            ok = subprocess.call([
                'cp', '-r',
                'data/' + label_one + '_copy/',
                'data/' + teacher_pro_data, '-f'
            ])
            assert ok == 0
            time.sleep(2)

        # Generate data from learning teacher at this checkpoint
        printYellow("\nGenerating on-policy data from the optimal teacher: " + label_two)
        OnPolicyDatasetGenerator(teacher_learn,
                                 teacher_learn_data,
                                 task_id=label_two,
                                 episode=eps,
                                 num_eps=args.epochs_teacher_datasets,
                                 env_name=args.env)

        if use_teacher_one:
            # merge the data of both teachers and expose it to srl_zoo
            mergeData('data/' + teacher_pro_data,
                      'data/' + teacher_learn_data,
                      merge_path,
                      force=True)
            ok = subprocess.call(
                ['cp', '-r', 'data/on_policy_merged/', 'srl_zoo/data/', '-f'])
        else:
            # single teacher: train the student on the learning teacher's data directly
            merge_path = 'data/' + teacher_learn_data
            ok = subprocess.call(
                ['cp', '-r', merge_path, 'srl_zoo/data/', '-f'])
        assert ok == 0
        time.sleep(2)

        # Train a policy with distillation on the merged teacher's datasets
        trainStudent('srl_zoo/' + merge_path,
                     label_two,
                     yaml_file=args.srl_config_file_one,
                     log_dir=args.log_dir_student,
                     srl_model=args.student_srl_model,
                     env_name=args.env,
                     training_size=args.distillation_training_set_size,
                     epochs=args.epochs_distillation)
        student_path = args.log_dir_student + args.env + '/' + args.student_srl_model + "/distillation/"
        # The most recently modified run folder is taken as the student just trained.
        latest_student_path = max([
            student_path + "/" + d for d in os.listdir(student_path)
            if os.path.isdir(student_path + "/" + d)
        ], key=os.path.getmtime) + '/'
        rewards = {}
        printRed("\nSaving the student at path: " + latest_student_path)
        for task_label in ["-sc", "-cc"]:
            rewards[task_label] = []

            for seed_i in range(args.num_iteration):
                printYellow("\nEvaluating student on task: " + task_label +
                            " for seed: " + str(seed_i))
                command_line_enjoy_student = [
                    'python', '-m', 'replay.enjoy_baselines',
                    '--num-timesteps', '251', '--log-dir', latest_student_path,
                    task_label, "--seed",
                    str(seed_i)
                ]
                output = subprocess.check_output(command_line_enjoy_student).decode('utf-8')
                rewards[task_label].append(parseMeanReward(output))
        print("rewards at eps ", eps, ": ", rewards)
        rewards_at_episode[eps] = rewards
    print("All rewards: ", rewards_at_episode)
    json_dict = json.dumps(rewards_at_episode)
    json_dict_name = \
        args.log_dir_student + "/reward_at_episode_" + datetime.datetime.now().strftime("%y-%m-%d_%Hh%M_%S") + '.json'
    with open(json_dict_name, "w") as f:
        f.write(json_dict)
    printRed("\nSaving the evalation at path: " + json_dict_name)
Exemplo n.º 5
0
def main():
    """Parse CLI arguments, configure the environment and train an RL agent.

    Workflow:
      1. Parse the generic training arguments. Unknown arguments are tolerated
         at this stage, because the algorithm gets a chance to extend the parser
         later via ``algo.customArguments``.
      2. Load the SRL model configuration file and sanity-check the arguments.
      3. Populate the module-level globals that ``callback`` reads.
      4. Re-parse the arguments strictly with the extended parser.
      5. Configure the environment and log folder.
      6. Save the arguments and the environment parameters.
      7. Seed the random generators, apply ``--hyperparam`` overrides and train.

    Raises:
        AssertionError: on invalid arguments or configuration. These checks are
            kept as ``assert`` to preserve this CLI's original failure mode.
        ValueError: if the requested action type is not supported by the
            algorithm, or if a ``--hyperparam`` entry is not ``name:value``.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='RL algo to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--episode_window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir', default='/tmp/gym/', type=str,
                        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1, help='number of frames to stack (default: 1)')
    parser.add_argument('--action-repeat', type=int, default=1,
                        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097, help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False, help='disables visdom visualization')
    parser.add_argument('--shape-reward', action='store_true', default=False,
                        help='Shape the reward (reward = - distance) instead of a sparse reward')
    parser.add_argument('-c', '--continuous-actions', action='store_true', default=False)
    parser.add_argument('-joints', '--action-joints', action='store_true', default=False,
                        help='set actions to the joints of the arm directly, instead of inverse kinematics')
    parser.add_argument('-r', '--random-target', action='store_true', default=False,
                        help='Set the button to a random position')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument('--latest', action='store_true', default=False,
                        help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument('--load-rl-model-path', type=str, default=None,
                        help="load the trained RL model, should be with the same algorithm type")

    # Ignore unknown args for now: algorithm-specific flags are only registered
    # later by algo.customArguments(), after which the args are parsed strictly.
    args, _ = parser.parse_known_args()
    env_kwargs = {}

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # yaml.load() without an explicit Loader is deprecated (PyYAML >= 5.1) and
        # raises a TypeError on PyYAML >= 6. Use safe_load instead.
        # NOTE(review): this assumes the config only uses plain YAML types.
        all_models = yaml.safe_load(f)

    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    # 65535 is a valid port: the accepted range is inclusive, as the message states.
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    compatible_classes = registered_srl[args.srl_model][1]
    if compatible_classes is not None:
        env_base_class = registered_env[args.env][0]
        # NOTE(review): this checks issubclass(compatible_class, env_class), i.e. the
        # env class must be a base of a compatible class. Confirm the direction is intended.
        found = any(issubclass(compatible_class, env_base_class) for compatible_class in compatible_classes)
        assert found, "Error: srl_model {}, is not compatible with the {} environment.".format(args.srl_model, args.env)

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save

    if args.no_vis:
        viz = False

    algo_class, _, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(args.algo + " does not support discrete actions, please use the '--continuous-actions' " +
                         "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(args.algo + " does not support continuous actions, please remove the '--continuous-actions' " +
                         "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # Allow up action
    # env_kwargs["force_down"] = False

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    # Let the algorithm add its own arguments, then parse strictly (unknown args now fail).
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    def _default_init_kwargs(klass):
        """Return ``{name: default}`` for the ``klass.__init__`` parameters that declare a default.

        Parameters without a default (e.g. ``self``, ``*args``, ``**kwargs``) are skipped.
        """
        return {name: param.default
                for name, param in inspect.signature(klass.__init__).parameters.items()
                if param.default is not inspect.Parameter.empty}

    env_class = registered_env[args.env][0]
    # env default kwargs
    default_env_kwargs = _default_init_kwargs(env_class)

    globals_env_param = sys.modules[env_class.__module__].getGlobals()

    super_class = registered_env[args.env][1]
    # Recursive search through all the super classes of the requested environment,
    # in order to collect all their arguments and module globals.
    rec_super_class_lookup = {dict_class: dict_super_class
                              for dict_class, dict_super_class, _, _ in registered_env.values()}
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        # Values already collected from subclasses take precedence over the super class ones.
        default_env_kwargs = {**_default_init_kwargs(super_class), **default_env_kwargs}
        globals_env_param = {**sys.modules[super_class.__module__].getGlobals(), **globals_env_param}

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(filterJSONSerializableObjects({**globals_env_param, **default_env_kwargs, **env_kwargs}))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using multiprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameters, if given (Hyperband), as "name:value" strings
    hyperparams = {}
    for param in args.hyperparam:
        if ":" not in param:
            raise ValueError("Error: invalid --hyperparam {!r}, expected the format 'name:value'.".format(param))
        # Split only on the first ':' so that values containing ':' are kept intact.
        name, value = param.split(":", 1)
        hyperparams[name] = value
    hyperparams = algo.parserHyperParam(hyperparams)

    if args.load_rl_model_path is not None:
        # Continue training from a saved model: use a small learning rate.
        # NOTE(review): the lambda presumably receives a schedule/progress factor — confirm with the algo.
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4
        algo.setLoadPath(args.load_rl_model_path)

    # Train the agent
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)