Example #1
def main():
    parser = argparse.ArgumentParser(description="Hyperparameter search for implemented RL models")
    parser.add_argument('--optimizer', default='hyperband', choices=['hyperband', 'hyperopt'], type=str,
                        help='The hyperparameter optimizer to choose from')
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='OpenAI baseline to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6), help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Display baseline STDOUT')
    parser.add_argument('--max-eval', type=int, default=100, help='Number of evaluations to try for hyperopt')

    args, train_args = parser.parse_known_args()
    args.log_dir = "logs/_{}_search/".format(args.optimizer)

    train_args.extend(['--srl-model', args.srl_model, '--seed', str(args.seed), '--algo', args.algo, '--env', args.env,
                       '--log-dir', args.log_dir, '--no-vis'])

    # Verify the algorithm defines getOptParam() and that it returns a usable value
    try:
        opt_param = registered_rl[args.algo][0].getOptParam()
        assert opt_param is not None
    except (AttributeError, AssertionError):
        raise AssertionError("Error: {} algo does not support hyperparameter search.".format(args.algo))

    if args.optimizer == "hyperband":
        opt = Hyperband(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                        max_iter=args.num_timesteps // ITERATION_SCALE)
    elif args.optimizer == "hyperopt":
        opt = Hyperopt(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed, num_eval=args.max_eval)
    else:
        raise ValueError("Error: optimizer {} was defined but not implemented, Halting.".format(args.optimizer))

    t_start = time.time()
    opt.run()
    all_params, loss = zip(*opt.history)
    idx = np.argmin(loss)
    opt_params, nb_iter = all_params[idx]
    reward = loss[idx]
    print('\ntime to run : {}s'.format(int(time.time() - t_start)))
    print('Total nb. evaluations : {}'.format(len(all_params)))
    if nb_iter is not None:
        print('Best nb. of iterations : {}'.format(int(nb_iter)))
    print('Best params : ')
    pprint.pprint(opt_params)
    print('Best reward : {:.3f}'.format(-reward))

    param_dict, timesteps = zip(*all_params)
    output = pd.DataFrame(list(param_dict))
    # make sure we returned a timestep value to log, otherwise ignore
    if not any([el is None for el in timesteps]):
        output["timesteps"] = np.array(np.maximum(MIN_ITERATION, np.array(timesteps) * ITERATION_SCALE).astype(int))
    output["reward"] = -np.array(loss)
    output.to_csv("logs/{}_{}_{}_{}_seed{}_numtimestep{}.csv"
                  .format(args.optimizer, args.algo, args.env, args.srl_model, args.seed, args.num_timesteps))
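
The try/except above only verifies that the chosen algorithm class exposes getOptParam() and that it returns something usable. A minimal sketch of what such a method could look like (the class name, parameter names, and ranges here are hypothetical, not taken from the repository):

class PPO2Model:
    """Hypothetical RL model wrapper exposing a hyperparameter search space."""

    @classmethod
    def getOptParam(cls):
        # One possible convention: map each tunable hyperparameter to a
        # (type, (low, high)) tuple the optimizer can sample from.
        return {
            "learning_rate": (float, (1e-5, 1e-2)),
            "n_steps": (int, (32, 2048)),
            "ent_coef": (float, (0.0, 0.01)),
        }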
Example #2
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters
    :param load_args: (Arguments)
    :return: (dict, str, str, str, dict)
    """
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in load_args.log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)

    env_globals = json.load(open(load_args.log_dir + "env_globals.json", 'r'))
    train_args = json.load(open(load_args.log_dir + "args.json", 'r'))

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"]
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')

        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the saved SRL model exists on disk
            assert os.path.isfile(path), "{} does not exist".format(path)
            srl_model_path = path
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
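
The algo-name lookup at the top of loadConfigAndSetup simply scans the log-directory path for a registered key. A standalone sketch of the same idea, with a placeholder registry:

registered_rl = {"ppo2": None, "a2c": None, "ddpg": None}  # placeholder registry

def detectAlgoName(log_dir):
    # Return the first registered algorithm whose name appears in the path
    for algo in registered_rl:
        if algo in log_dir:
            return algo
    raise ValueError("no registered algorithm found in path: " + log_dir)

print(detectAlgoName("logs/KukaButtonGymEnv-v0/raw_pixels/ppo2/19-01-01"))  # -> ppo2

Note that plain substring matching can misfire when one algorithm name happens to appear inside a dataset or folder name, so the registry keys must stay unambiguous within a path.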
Example #3
def main():
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(
        description="Train script for RL algorithms")
    parser.add_argument('--algo',
                        default='ppo2',
                        choices=list(registered_rl.keys()),
                        help='RL algo to use',
                        type=str)
    parser.add_argument('--env',
                        type=str,
                        help='environment ID',
                        default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='random seed (default: 0)')
    parser.add_argument(
        '--episode-window',
        type=int,
        default=40,
        help='Episode window for moving average plot (default: 40)')
    parser.add_argument(
        '--log-dir',
        default='/tmp/gym/',
        type=str,
        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model',
                        type=str,
                        default='raw_pixels',
                        choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack',
                        type=int,
                        default=1,
                        help='number of frames to stack (default: 1)')
    parser.add_argument(
        '--action-repeat',
        type=int,
        default=1,
        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port',
                        type=int,
                        default=8097,
                        help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis',
                        action='store_true',
                        default=False,
                        help='disables visdom visualization')
    parser.add_argument(
        '--shape-reward',
        action='store_true',
        default=False,
        help='Shape the reward (reward = - distance) instead of a sparse reward'
    )
    parser.add_argument('-c',
                        '--continuous-actions',
                        action='store_true',
                        default=False)
    parser.add_argument(
        '-joints',
        '--action-joints',
        action='store_true',
        default=False,
        help=
        'set actions to the joints of the arm directly, instead of inverse kinematics'
    )
    parser.add_argument('-r',
                        '--random-target',
                        action='store_true',
                        default=False,
                        help='Set the button to a random position')
    parser.add_argument(
        '--srl-config-file',
        type=str,
        default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save',
                        type=int,
                        default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument(
        '--latest',
        action='store_true',
        default=False,
        help=
        'load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument(
        '--load-rl-model-path',
        type=str,
        default=None,
        help="load the trained RL model, should be with the same algorithm type"
    )
    parser.add_argument(
        '-sc',
        '--simple-continual',
        action='store_true',
        default=False,
        help=
        'Simple red square target for task 1 of continual learning scenario. '
        + 'The task is: the robot should reach the target.')
    parser.add_argument(
        '-cc',
        '--circular-continual',
        action='store_true',
        default=False,
        help='Blue square target for task 2 of continual learning scenario. ' +
        'The task is: the robot should circle around the target.')
    parser.add_argument(
        '-sqc',
        '--square-continual',
        action='store_true',
        default=False,
        help='Green square target for task 3 of continual learning scenario. '
        + 'The task is: the robot should move in a square around the target.')
    parser.add_argument(
        '-ec',
        '--eight-continual',
        action='store_true',
        default=False,
        help='Green square target for task 4 of continual learning scenario. '
        + 'The task is: the robot should trace a figure eight centered on the target.'
    )
    parser.add_argument('--teacher-data-folder',
                        type=str,
                        default="",
                        help='Dataset folder of the teacher(s) policy(ies)',
                        required=False)
    parser.add_argument(
        '--epochs-distillation',
        type=int,
        default=30,
        metavar='N',
        help='number of epochs to train for distillation (default: 30)')
    parser.add_argument(
        '--distillation-training-set-size',
        type=int,
        default=-1,
        help='Limit size (number of samples) of the training set (default: -1)'
    )
    parser.add_argument(
        '--perform-cross-evaluation-cc',
        action='store_true',
        default=False,
        help='A cross evaluation from the latest stored model to all tasks')
    parser.add_argument(
        '--eval-episode-window',
        type=int,
        default=400,
        metavar='N',
        help=
        'Episode window for saving each policy checkpoint for future distillation (default: 400)'
    )
    parser.add_argument(
        '--new-lr',
        type=float,
        default=1.e-4,
        help="New learning rate ratio to train a pretrained agent")
    parser.add_argument('--img-shape',
                        type=str,
                        default="(3,64,64)",
                        help="Image shape of environment.")
    parser.add_argument(
        "--gpu-num",
        help="Choose the number of GPU (CUDA_VISIBLE_DEVICES).",
        type=str,
        default="1",
        choices=["0", "1", "2", "3", "5", "6", "7", "8"])
    parser.add_argument("--srl-model-path",
                        help="SRL model weights path",
                        type=str,
                        default=None)
    parser.add_argument(
        "--relative-pos",
        action='store_true',
        default=False,
        help="For 'ground_truth': use relative position or not.")
    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    env_kwargs = {}
    if args.img_shape is None:
        img_shape = None  # e.g. (3, 224, 224)
    else:
        img_shape = tuple(map(int, args.img_shape[1:-1].split(",")))
    env_kwargs['img_shape'] = img_shape
    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        all_models = yaml.load(f, Loader=yaml.SafeLoader)
    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that the SRL model can be run on the environment
    if registered_srl[args.srl_model][1] is not None:
        found = False
        for compatible_class in registered_srl[args.srl_model][1]:
            if issubclass(compatible_class, registered_env[args.env][0]):
                found = True
                break
        assert found, "Error: srl_model {} is not compatible with the {} environment.".format(
            args.srl_model, args.env)

    assert not(sum([args.simple_continual, args.circular_continual, args.square_continual, args.eight_continual]) \
           > 1 and args.env == "OmnirobotEnv-v0"), \
        "For continual SRL and RL, please provide only one scenario at the time and use OmnirobotEnv-v0 environment !"

    assert not(args.algo == "distillation" and (args.teacher_data_folder == '' or args.continuous_actions is True)), \
        "For performing policy distillation, make sure use specify a valid teacher dataset and discrete actions !"

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save
    CROSS_EVAL = args.perform_cross_evaluation_cc
    EPISODE_WINDOW_DISTILLATION_WIN = args.eval_episode_window
    NEW_LR = args.new_lr
    print("EPISODE_WINDOW_DISTILLATION_WIN: ", EPISODE_WINDOW_DISTILLATION_WIN)

    if args.no_vis:
        viz = False

    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(
            args.algo +
            " does not support discrete actions, please use the '--continuous-actions' "
            + "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(
            args.algo +
            " does not support continuous actions, please remove the '--continuous-actions' "
            + "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target

    # If in simple continual scenario, then the target should be initialized randomly.
    if args.simple_continual is True:
        env_kwargs["random_target"] = True

    # Allow up action
    # env_kwargs["force_down"] = False

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    env_class = registered_env[args.env][0]
    # env default kwargs
    default_env_kwargs = {
        k: v.default
        for k, v in inspect.signature(env_class.__init__).parameters.items()
        if v is not None
    }

    globals_env_param = sys.modules[env_class.__module__].getGlobals()
    # HACK: override the image shape in the environment globals
    globals_env_param['RENDER_HEIGHT'] = img_shape[1]
    globals_env_param['RENDER_WIDTH'] = img_shape[2]
    globals_env_param['RELATIVE_POS'] = args.relative_pos

    super_class = registered_env[args.env][1]
    # Recursively search through all the super classes of the requested environment, in order to collect all the arguments.
    rec_super_class_lookup = {
        dict_class: dict_super_class
        for _, (dict_class, dict_super_class, _, _) in registered_env.items()
    }
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {
            k: v.default
            for k, v in inspect.signature(
                super_class.__init__).parameters.items() if v is not None
        }
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}

        globals_env_param = {
            **sys.modules[super_class.__module__].getGlobals(),
            **globals_env_param
        }

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(
        filterJSONSerializableObjects({
            **globals_env_param,
            **default_env_kwargs,
            **env_kwargs
        }))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using multiprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband)
    hyperparams = {
        param.split(":")[0]: param.split(":")[1]
        for param in args.hyperparam
    }
    hyperparams = algo.parserHyperParam(hyperparams)

    if args.load_rl_model_path is not None:
        # use a small learning rate
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4

    # Train the agent
    if args.load_rl_model_path is not None:
        algo.setLoadPath(args.load_rl_model_path)
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
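
The --hyperparam flag above expects space-separated name:value pairs, which are split into a plain dict before algo.parserHyperParam casts the values. A standalone sketch of that convention (splitting on the first ':' only, so values containing colons survive):

def parseHyperparamPairs(pairs):
    # "learning_rate:0.0003" -> {"learning_rate": "0.0003"}
    return {p.split(":", 1)[0]: p.split(":", 1)[1] for p in pairs}

print(parseHyperparamPairs(["learning_rate:0.0003", "n_steps:128"]))
# -> {'learning_rate': '0.0003', 'n_steps': '128'}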
Example #4
def main():
    parser = argparse.ArgumentParser(
        description="OpenAI RL Baselines Benchmark",
        epilog=
        'After the arguments are parsed, the rest are assumed to be arguments for'
        + ' rl_baselines.train')
    parser.add_argument('--algo',
                        type=str,
                        default='ppo2',
                        help='OpenAI baseline to use',
                        choices=list(registered_rl.keys()))
    parser.add_argument('--env',
                        type=str,
                        nargs='+',
                        default=["KukaButtonGymEnv-v0"],
                        help='environment ID(s)',
                        choices=list(registered_env.keys()))
    parser.add_argument('--srl-model',
                        type=str,
                        nargs='+',
                        default=["raw_pixels"],
                        help='SRL model(s) to use',
                        choices=list(registered_srl.keys()))
    parser.add_argument('--num-timesteps',
                        type=int,
                        default=int(1e6),
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        default=False,
                        help='Display baseline STDOUT')
    parser.add_argument(
        '--num-iteration',
        type=int,
        default=15,
        help=
        'number of times each algorithm should be run for each unique combination of environment '
        + 'and srl-model.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help=
        'initial seed for each unique combination of environment and srl-model.'
    )
    parser.add_argument(
        '--srl-config-file',
        type=str,
        default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')

    # returns the parsed arguments, and the rest are assumed to be arguments for rl_baselines.train
    args, train_args = parser.parse_known_args()

    # Sanity check
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_iteration >= 1, "Error: --num-iteration cannot be less than 1"

    # Remove duplicates and sort
    srl_models = list(set(args.srl_model))
    envs = list(set(args.env))
    srl_models.sort()
    envs.sort()

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        all_models = yaml.load(f, Loader=yaml.SafeLoader)

    # Checking definition and presence of all requested srl_models
    valid = True
    for env in envs:
        # validate the env definition
        if env not in all_models:
            printRed(
                "Error: 'srl_models.yaml' missing definition for environment {}"
                .format(env))
            valid = False
            continue  # skip to the next env, this one is not valid

        # checking log_folder for current env
        missing_log = "log_folder" not in all_models[env]
        if missing_log:
            printRed(
                "Error: 'srl_models.yaml' missing definition for log_folder in environment {}"
                .format(env))
            valid = False

        # validate each model for the current env definition
        for model in srl_models:
            if registered_srl[model][0] == SRLType.ENVIRONMENT:
                continue  # not an srl model, skip to the next model
            elif model not in all_models[env]:
                printRed(
                    "Error: 'srl_models.yaml' missing srl_model {} for environment {}"
                    .format(model, env))
                valid = False
            elif (not missing_log) and (
                    not os.path.exists(all_models[env]["log_folder"] +
                                       all_models[env][model])):
                # checking presence of srl_model path, if and only if log_folder exists
                printRed(
                    "Error: srl_model {} for environment {} was defined in ".
                    format(model, env) +
                    "'srl_models.yaml', however the file {} it was tagetting does not exist."
                    .format(all_models[env]["log_folder"] +
                            all_models[env][model]))
                valid = False

    assert valid, "Errors occured due to malformed 'srl_models.yaml', cannot continue."

    # check that all the SRL_models can be run on all the environments
    valid = True
    for env in envs:
        for model in srl_models:
            if registered_srl[model][1] is not None:
                found = False
                for compatible_class in registered_srl[model][1]:
                    if issubclass(compatible_class, registered_env[env][0]):
                        found = True
                        break
                if not found:
                    valid = False
                    printRed(
                        "Error: srl_model {} is not compatible with the {} environment."
                        .format(model, env))
    assert valid, "Errors occured due to an incompatible combination of srl_model and environment, cannot continue."

    # the seeds used in training the baseline.
    seeds = list(np.arange(args.num_iteration) + args.seed)

    if args.verbose:
        # None here means stdout of terminal for subprocess.call
        stdout = None
    else:
        stdout = open(os.devnull, 'w')

    printGreen("\nRunning {} benchmarks {} times...".format(
        args.algo, args.num_iteration))
    print("\nSRL-Models:\t{}".format(srl_models))
    print("environments:\t{}".format(envs))
    print("verbose:\t{}".format(args.verbose))
    print("timesteps:\t{}".format(args.num_timesteps))
    for model in srl_models:
        for env in envs:
            for i in range(args.num_iteration):

                printGreen(
                    "\nIteration_num={} (seed: {}), Environment='{}', SRL-Model='{}'"
                    .format(i, seeds[i], env, model))

                # redefine the parsed args for rl_baselines.train
                loop_args = [
                    '--srl-model', model, '--seed',
                    str(seeds[i]), '--algo', args.algo, '--env', env,
                    '--num-timesteps',
                    str(int(args.num_timesteps)), '--srl-config-file',
                    args.srl_config_file
                ]

                ok = subprocess.call(['python', '-m', 'rl_baselines.train'] +
                                     train_args + loop_args,
                                     stdout=stdout)

                if ok != 0:
                    # propagate the error to the terminal
                    raise ChildProcessError(
                        "An error occurred, error code: {}".format(ok))
Example #5
def main():
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='RL algo to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--episode_window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir', default='/tmp/gym/', type=str,
                        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1, help='number of frames to stack (default: 1)')
    parser.add_argument('--action-repeat', type=int, default=1,
                        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097, help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False, help='disables visdom visualization')
    parser.add_argument('--shape-reward', action='store_true', default=False,
                        help='Shape the reward (reward = - distance) instead of a sparse reward')
    parser.add_argument('-c', '--continuous-actions', action='store_true', default=False)
    parser.add_argument('-joints', '--action-joints', action='store_true', default=False,
                        help='set actions to the joints of the arm directly, instead of inverse kinematics')
    parser.add_argument('-r', '--random-target', action='store_true', default=False,
                        help='Set the button to a random position')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument('--latest', action='store_true', default=False,
                        help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument('--load-rl-model-path', type=str, default=None,
                        help="load the trained RL model, should be with the same algorithm type")
    
    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    env_kwargs = {}

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        all_models = yaml.load(f, Loader=yaml.SafeLoader)

    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that the SRL model can be run on the environment
    if registered_srl[args.srl_model][1] is not None:
        found = False
        for compatible_class in registered_srl[args.srl_model][1]:
            if issubclass(compatible_class, registered_env[args.env][0]):
                found = True
                break
        assert found, "Error: srl_model {} is not compatible with the {} environment.".format(args.srl_model, args.env)

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save

    if args.no_vis:
        viz = False

    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo
    

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(args.algo + " does not support discrete actions, please use the '--continuous-actions' " +
                         "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(args.algo + " does not support continuous actions, please remove the '--continuous-actions' " +
                         "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # Allow up action
    # env_kwargs["force_down"] = False

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    env_class = registered_env[args.env][0]
    # env default kwargs
    default_env_kwargs = {k: v.default
                          for k, v in inspect.signature(env_class.__init__).parameters.items()
                          if v is not None}

    globals_env_param = sys.modules[env_class.__module__].getGlobals()

    super_class = registered_env[args.env][1]
    # Recursively search through all the super classes of the requested environment, in order to collect all the arguments.
    rec_super_class_lookup = {dict_class: dict_super_class for _, (dict_class, dict_super_class, _, _) in
                              registered_env.items()}
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {k: v.default
                            for k, v in inspect.signature(super_class.__init__).parameters.items()
                            if v is not None}
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}

        globals_env_param = {**sys.modules[super_class.__module__].getGlobals(), **globals_env_param}

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(filterJSONSerializableObjects({**globals_env_param, **default_env_kwargs, **env_kwargs}))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using multiprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband)
    hyperparams = {param.split(":")[0]: param.split(":")[1] for param in args.hyperparam}
    hyperparams = algo.parserHyperParam(hyperparams)
    
    if args.load_rl_model_path is not None:
        # use a small learning rate
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4
        
    # Train the agent

    if args.load_rl_model_path is not None:
        algo.setLoadPath(args.load_rl_model_path)
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
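
The learning_rate entry set above is a schedule rather than a constant: trainers in the stable-baselines lineage call it with the remaining training progress f (decreasing from 1 to 0), so lambda f: f * 1.0e-4 decays linearly from 1e-4 towards 0. A self-contained illustration, where the calling convention is the assumption:

learning_rate = lambda f: f * 1.0e-4  # same schedule as above

for step, total in [(0, 100), (50, 100), (99, 100)]:
    f = 1.0 - step / total  # fraction of training remaining
    print("step {:3d}: lr = {:.2e}".format(step, learning_rate(f)))
# step   0: lr = 1.00e-04
# step  50: lr = 5.00e-05
# step  99: lr = 1.00e-06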
Example #6
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters
    :param load_args: (Arguments)
    :return: (dict, str, str, str, dict)
    """
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in load_args.log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    try:  # If args contains episode information, this is for student_evaluation (distillation)
        if load_args.episode == -1:
            load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)
        else:
            load_path = "{}/{}_{}_model.pkl".format(load_args.log_dir, algo_name, load_args.episode)
    except AttributeError:
        printYellow(
            "No episode checkpoint specified, falling back to the default policy model: {}_model.pkl".format(algo_name))
        if load_args.log_dir[-3:] != 'pkl':
            load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)
        else:
            load_path = load_args.log_dir
            load_args.log_dir = os.path.dirname(load_path) + '/'

    env_globals = json.load(open(load_args.log_dir + "env_globals.json", 'r'))
    train_args = json.load(open(load_args.log_dir + "args.json", 'r'))

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"]
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get("simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get("circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get("square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get("eight_continual_move", False)

        # If overriding the environment for specific Continual Learning tasks
        if sum([load_args.simple_continual, load_args.circular_continual, load_args.square_continual]) >= 1:
            env_kwargs["simple_continual_target"] = load_args.simple_continual
            env_kwargs["circular_continual_move"] = load_args.circular_continual
            env_kwargs["square_continual_move"] = load_args.square_continual
            env_kwargs["random_target"] = not (load_args.circular_continual or load_args.square_continual)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')

        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the srl saved model exists on the disk
            assert os.path.isfile(env_globals['srl_model_path']), \
                "{} does not exist".format(env_globals['srl_model_path'])
            srl_model_path = env_globals['srl_model_path']
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
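
This variant also resolves episode-specific checkpoints named {algo}_{episode}_model.pkl. A standalone sketch of that naming scheme (paths are placeholders):

import os

def buildLoadPath(log_dir, algo_name, episode=-1):
    # episode == -1 means "use the final model", mirroring the logic above
    if episode != -1:
        return os.path.join(log_dir, "{}_{}_model.pkl".format(algo_name, episode))
    return os.path.join(log_dir, "{}_model.pkl".format(algo_name))

print(buildLoadPath("logs/ppo2/19-05-07", "ppo2", episode=500))
# -> logs/ppo2/19-05-07/ppo2_500_model.pkl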
Example #7
def loadConfigAndSetup(log_dir):
    """
    Load training variables from a pre-trained model
    :param log_dir: the path where the model is located,
    example: logs/sc2cc/OmnirobotEnv-v0/srl_combination/ppo2/19-05-07_11h32_39
    :return: train_args, algo_name, algo_class (stable_baselines.PPO2), srl_model_path, env_kwargs
    """
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for evaluation")

    env_globals = json.load(open(log_dir + "env_globals.json", 'r'))
    train_args = json.load(open(log_dir + "args.json", 'r'))
    env_kwargs = {
        "renders": False,
        # TODO: since we don't use a simple target, should we eliminate this choice?
        "shape_reward": False,
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"]
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get(
            "simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get(
            "circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get(
            "square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get(
            "eight_continual_move", False)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')

        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the saved SRL model exists on disk
            assert os.path.isfile(path), "{} does not exist".format(path)
            srl_model_path = path
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, algo_name, algo_class, srl_model_path, env_kwargs
Example #8
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters
    :param load_args: (Arguments)
    :return: (dict, str, str, str, dict)
    """
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in load_args.log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)

    env_globals = json.load(open(load_args.log_dir + "env_globals.json", 'r'))
    train_args = json.load(open(load_args.log_dir + "args.json", 'r'))
    if train_args.get("img_shape") is None:
        img_shape = None  # e.g. (3, 224, 224)
    else:
        img_shape = tuple(map(int, train_args["img_shape"][1:-1].split(",")))

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"],
        "img_shape": img_shape
        # "img_shape" : train_args.get("img_shape", None)
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get(
            "simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get(
            "circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get(
            "square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get(
            "eight_continual_move", False)

        # If overriding the environment for specific Continual Learning tasks
        if sum([load_args.simple_continual, load_args.circular_continual,
                load_args.square_continual]) >= 1:
            env_kwargs["simple_continual_target"] = load_args.simple_continual
            env_kwargs["circular_continual_move"] = load_args.circular_continual
            env_kwargs["square_continual_move"] = load_args.square_continual
            env_kwargs["random_target"] = not (load_args.circular_continual
                                               or load_args.square_continual)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')

        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the saved SRL model exists on disk
            assert os.path.isfile(path), "{} does not exist".format(path)
            srl_model_path = path
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
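
The img_shape handling above parses a literal string such as "(3,64,64)" by stripping the outer parentheses and splitting on commas. A slightly more defensive standalone sketch using ast.literal_eval, which also tolerates spaces (the helper name is hypothetical):

import ast

def parseImgShape(shape_str):
    if shape_str is None:
        return None
    shape = ast.literal_eval(shape_str)  # "(3, 64, 64)" -> (3, 64, 64)
    assert len(shape) == 3, "expected (channels, height, width)"
    return tuple(int(v) for v in shape)

print(parseImgShape("(3,64,64)"))  # -> (3, 64, 64)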