def main():
    """
    Entry point for hyperparameter search over the implemented RL models.

    Parses the CLI arguments, runs the chosen optimizer (Hyperband or
    Hyperopt) against the selected algorithm/environment/SRL-model
    combination, then prints the best configuration found and dumps the
    whole search history to a CSV file under logs/.
    """
    parser = argparse.ArgumentParser(description="Hyperparameter search for implemented RL models")
    parser.add_argument('--optimizer', default='hyperband', choices=['hyperband', 'hyperopt'], type=str,
                        help='The hyperparameter optimizer to choose from')
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()),
                        help='OpenAI baseline to use', type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    # FIX: the default must already be an int — argparse applies type= only to
    # strings coming from the command line, so a bare 1e6 default silently
    # stayed a float (and leaked into the // division and the CSV filename).
    parser.add_argument('--num-timesteps', type=int, default=int(1e6),
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='Display baseline STDOUT')
    parser.add_argument('--max-eval', type=int, default=100,
                        help='Number of evalutation to try for hyperopt')
    args, train_args = parser.parse_known_args()

    args.log_dir = "logs/_{}_search/".format(args.optimizer)
    # Forward the shared options to the underlying rl_baselines.train invocation
    train_args.extend(['--srl-model', args.srl_model, '--seed', str(args.seed), '--algo', args.algo,
                       '--env', args.env, '--log-dir', args.log_dir, '--no-vis'])

    # verify the algorithm has defined it, and that it returnes an expected value
    try:
        opt_param = registered_rl[args.algo][0].getOptParam()
        assert opt_param is not None
    except (AttributeError, AssertionError):
        # BUGFIX: the original `except AttributeError or AssertionError:` only
        # caught AttributeError — `or` evaluates to its first truthy operand,
        # it does not build an exception tuple, so a None opt_param slipped
        # through as an unhandled AssertionError.
        raise AssertionError("Error: {} algo does not support hyperparameter search.".format(args.algo))

    if args.optimizer == "hyperband":
        opt = Hyperband(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                        max_iter=args.num_timesteps // ITERATION_SCALE)
    elif args.optimizer == "hyperopt":
        opt = Hyperopt(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                       num_eval=args.max_eval)
    else:
        raise ValueError("Error: optimizer {} was defined but not implemented, Halting.".format(args.optimizer))

    t_start = time.time()
    opt.run()

    # history entries are ((params, nb_iter), loss); loss is the negated reward
    all_params, loss = zip(*opt.history)
    idx = np.argmin(loss)
    opt_params, nb_iter = all_params[idx]
    reward = loss[idx]

    print('\ntime to run : {}s'.format(int(time.time() - t_start)))
    print('Total nb. evaluations : {}'.format(len(all_params)))
    if nb_iter is not None:
        print('Best nb. of iterations : {}'.format(int(nb_iter)))
    print('Best params : ')
    pprint.pprint(opt_params)
    # negate the loss back into a reward for display
    print('Best reward : {:.3f}'.format(-reward))

    param_dict, timesteps = zip(*all_params)
    output = pd.DataFrame(list(param_dict))
    # make sure we returned a timestep value to log, otherwise ignore
    if not any([el is None for el in timesteps]):
        output["timesteps"] = np.array(np.maximum(MIN_ITERATION, np.array(timesteps) * ITERATION_SCALE).astype(int))
    output["reward"] = -np.array(loss)
    output.to_csv("logs/{}_{}_{}_{}_seed{}_numtimestep{}.csv"
                  .format(args.optimizer, args.algo, args.env, args.srl_model, args.seed, args.num_timesteps))
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters
    :param load_args: (Arguments)
    :return: (dict, str, str, str, dict)
    """
    # Recover the algorithm name from the log directory path; the first
    # registered algo whose name appears in the path wins.
    algo_name = ""
    for candidate in list(registered_rl.keys()):
        if candidate in load_args.log_dir:
            algo_name = candidate
            break

    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)

    # Reload the environment globals and the training arguments saved at train time
    with open(load_args.log_dir + "env_globals.json", 'r') as env_globals_file:
        env_globals = json.load(env_globals_file)
    with open(load_args.log_dir + "args.json", 'r') as train_args_file:
        train_args = json.load(train_args_file)

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"],
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action (forced down by default only for the 2-button Kuka env)
    env_kwargs["force_down"] = env_globals.get('force_down',
                                               train_args["env"] == "Kuka2ButtonGymEnv-v0")

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        if env_globals.get('srl_model_path') is not None:
            env_kwargs["use_srl"] = True
            # Check that the srl saved model exists on the disk
            assert os.path.isfile(env_globals['srl_model_path']), \
                "{} does not exist".format(env_globals['srl_model_path'])
            srl_model_path = env_globals['srl_model_path']
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
def main():
    """
    Train script entry point for the RL algorithms: parses the CLI options,
    validates them against the registered environments / SRL models, sets the
    module-level callback globals, then launches ``algo.train``.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(
        description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2',
                        choices=list(registered_rl.keys()),
                        help='RL algo to use', type=str)
    parser.add_argument('--env', type=str, help='environment ID',
                        default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument(
        '--episode-window', type=int, default=40,
        help='Episode window for moving average plot (default: 40)')
    parser.add_argument(
        '--log-dir', default='/tmp/gym/', type=str,
        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels',
                        choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1,
                        help='number of frames to stack (default: 1)')
    parser.add_argument(
        '--action-repeat', type=int, default=1,
        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097,
                        help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False,
                        help='disables visdom visualization')
    parser.add_argument(
        '--shape-reward', action='store_true', default=False,
        help='Shape the reward (reward = - distance) instead of a sparse reward'
    )
    parser.add_argument('-c', '--continuous-actions', action='store_true',
                        default=False)
    parser.add_argument(
        '-joints', '--action-joints', action='store_true', default=False,
        help=
        'set actions to the joints of the arm directly, instead of inverse kinematics'
    )
    parser.add_argument('-r', '--random-target', action='store_true',
                        default=False,
                        help='Set the button to a random position')
    parser.add_argument(
        '--srl-config-file', type=str, default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')
    # free-form "name:value" pairs, parsed into a dict further below
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument(
        '--latest', action='store_true', default=False,
        help=
        'load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument(
        '--load-rl-model-path', type=str, default=None,
        help="load the trained RL model, should be with the same algorithm type"
    )
    # Continual-learning scenario selectors (mutually exclusive, checked below)
    parser.add_argument(
        '-sc', '--simple-continual', action='store_true', default=False,
        help=
        'Simple red square target for task 1 of continual learning scenario. ' +
        'The task is: robot should reach the target.')
    parser.add_argument(
        '-cc', '--circular-continual', action='store_true', default=False,
        help='Blue square target for task 2 of continual learning scenario. ' +
        'The task is: robot should turn in circle around the target.')
    parser.add_argument(
        '-sqc', '--square-continual', action='store_true', default=False,
        help='Green square target for task 3 of continual learning scenario. ' +
        'The task is: robot should turn in square around the target.')
    parser.add_argument(
        '-ec', '--eight-continual', action='store_true', default=False,
        help='Green square target for task 4 of continual learning scenario. ' +
        'The task is: robot should do the eigth with the target as center of the shape.'
    )
    parser.add_argument('--teacher-data-folder', type=str, default="",
                        help='Dataset folder of the teacher(s) policy(ies)',
                        required=False)
    parser.add_argument(
        '--epochs-distillation', type=int, default=30, metavar='N',
        help='number of epochs to train for distillation(default: 30)')
    parser.add_argument(
        '--distillation-training-set-size', type=int, default=-1,
        help='Limit size (number of samples) of the training set (default: -1)'
    )
    parser.add_argument(
        '--perform-cross-evaluation-cc', action='store_true', default=False,
        help='A cross evaluation from the latest stored model to all tasks')
    # NOTE(review): help text says "default: 100" but the actual default is 400
    # — confirm which one is intended.
    parser.add_argument(
        '--eval-episode-window', type=int, default=400, metavar='N',
        help=
        'Episode window for saving each policy checkpoint for future distillation(default: 100)'
    )
    parser.add_argument(
        '--new-lr', type=float, default=1.e-4,
        help="New learning rate ratio to train a pretrained agent")
    parser.add_argument('--img-shape', type=str, default="(3,64,64)",
                        help="Image shape of environment.")
    # NOTE(review): "4" is missing from the GPU choices — confirm whether that
    # device is deliberately reserved or this is an oversight.
    parser.add_argument(
        "--gpu-num",
        help="Choose the number of GPU (CUDA_VISIBLE_DEVICES).",
        type=str, default="1",
        choices=["0", "1", "2", "3", "5", "6", "7", "8"])
    parser.add_argument("--srl-model-path",
                        help="SRL model weights path",
                        type=str, default=None)
    parser.add_argument(
        "--relative-pos", action='store_true', default=False,
        help="For 'ground_truth': use relative position or not.")
    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    env_kwargs = {}
    # NOTE(review): --img-shape defaults to the string "(3,64,64)", so this None
    # branch is unreachable from the CLI; if it ever triggers, img_shape[1]
    # below would raise TypeError — confirm intent.
    if args.img_shape is None:
        img_shape = None  #(3,224,224)
    else:
        # parse "(C,H,W)" string into a tuple of ints
        img_shape = tuple(map(int, args.img_shape[1:-1].split(",")))
    env_kwargs['img_shape'] = img_shape
    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # NOTE(review): yaml.load without an explicit Loader is unsafe on
        # untrusted files and deprecated in PyYAML >= 5.1 — consider
        # yaml.safe_load for this local config.
        all_models = yaml.load(f)
    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    assert 0 <= args.port < 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
        "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    if registered_srl[args.srl_model][1] is not None:
        found = False
        for compatible_class in registered_srl[args.srl_model][1]:
            if issubclass(compatible_class, registered_env[args.env][0]):
                found = True
                break
        assert found, "Error: srl_model {}, is not compatible with the {} environment.".format(
            args.srl_model, args.env)
    assert not(sum([args.simple_continual, args.circular_continual, args.square_continual, args.eight_continual]) \
               > 1 and args.env == "OmnirobotEnv-v0"), \
        "For continual SRL and RL, please provide only one scenario at the time and use OmnirobotEnv-v0 environment !"
    assert not(args.algo == "distillation" and (args.teacher_data_folder == '' or args.continuous_actions is True)), \
        "For performing policy distillation, make sure use specify a valid teacher dataset and discrete actions !"
    # Publish the parsed options into the callback globals
    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save
    # NOTE(review): CROSS_EVAL, EPISODE_WINDOW_DISTILLATION_WIN and NEW_LR are
    # not listed in the `global` statements above, so these assignments create
    # function-locals and do NOT update any module-level variables — confirm
    # whether they were meant to be globals like the others.
    CROSS_EVAL = args.perform_cross_evaluation_cc
    EPISODE_WINDOW_DISTILLATION_WIN = args.eval_episode_window
    NEW_LR = args.new_lr
    print("EPISODE_WINDOW_DISTILLATION_WIN: ", EPISODE_WINDOW_DISTILLATION_WIN)
    if args.no_vis:
        viz = False
    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo
    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL
    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(
            args.algo +
            " does not support discrete actions, please use the '--continuous-actions' "
            + "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(
            args.algo +
            " does not support continuous actions, please remove the '--continuous-actions' "
            + "(or '-c') flag.")
    env_kwargs["is_discrete"] = not args.continuous_actions
    printGreen("\nAgent = {} \n".format(args.algo))
    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # If in simple continual scenario, then the target should be initialized randomly.
    if args.simple_continual is True:
        env_kwargs["random_target"] = True
    # Allow up action
    # env_kwargs["force_down"] = False
    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    # Let the algorithm register its own extra CLI options, then re-parse
    parser = algo.customArguments(parser)
    args = parser.parse_args()
    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)
    env_class = registered_env[args.env][0]
    # env default kwargs
    default_env_kwargs = {
        k: v.default
        for k, v in inspect.signature(env_class.__init__).parameters.items()
        if v is not None
    }
    globals_env_param = sys.modules[env_class.__module__].getGlobals()
    ### HACK way to reset image shape !!
    globals_env_param['RENDER_HEIGHT'] = img_shape[1]
    globals_env_param['RENDER_WIDTH'] = img_shape[2]
    globals_env_param['RELATIVE_POS'] = args.relative_pos
    super_class = registered_env[args.env][1]
    # reccursive search through all the super classes of the asked environment, in order to get all the arguments.
    rec_super_class_lookup = {
        dict_class: dict_super_class
        for _, (dict_class, dict_super_class, _, _) in registered_env.items()
    }
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
            ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {
            k: v.default
            for k, v in inspect.signature(
                super_class.__init__).parameters.items() if v is not None
        }
        # subclass defaults take precedence over superclass defaults
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}
        globals_env_param = {
            **sys.modules[super_class.__module__].getGlobals(),
            **globals_env_param
        }
        super_class = rec_super_class_lookup[super_class]
    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(
        filterJSONSerializableObjects({
            **globals_env_param,
            **default_env_kwargs,
            **env_kwargs
        }))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using mutliprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband)
    hyperparams = {
        param.split(":")[0]: param.split(":")[1]
        for param in args.hyperparam
    }
    hyperparams = algo.parserHyperParam(hyperparams)
    if args.load_rl_model_path is not None:
        #use a small learning rate
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4
    # Train the agent
    if args.load_rl_model_path is not None:
        algo.setLoadPath(args.load_rl_model_path)
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
def main():
    """
    Benchmark driver: runs ``rl_baselines.train`` as a subprocess for every
    (srl_model, environment) combination, ``--num-iteration`` times each with
    consecutive seeds, after validating the requested combinations against
    the SRL model configuration file.

    :raises ChildProcessError: if any training subprocess exits non-zero
    """
    parser = argparse.ArgumentParser(
        description="OpenAI RL Baselines Benchmark",
        epilog=
        'After the arguments are parsed, the rest are assumed to be arguments for'
        + ' rl_baselines.train')
    parser.add_argument('--algo', type=str, default='ppo2',
                        help='OpenAI baseline to use',
                        choices=list(registered_rl.keys()))
    parser.add_argument('--env', type=str, nargs='+',
                        default=["KukaButtonGymEnv-v0"],
                        help='environment ID(s)',
                        choices=list(registered_env.keys()))
    parser.add_argument('--srl-model', type=str, nargs='+',
                        default=["raw_pixels"], help='SRL model(s) to use',
                        choices=list(registered_srl.keys()))
    # FIX: default must be an int — argparse does not apply type= to defaults,
    # so a bare 1e6 stayed a float.
    parser.add_argument('--num-timesteps', type=int, default=int(1e6),
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='Display baseline STDOUT')
    parser.add_argument(
        '--num-iteration', type=int, default=15,
        help=
        'number of time each algorithm should be run for each unique combination of environment '
        + ' and srl-model.')
    parser.add_argument(
        '--seed', type=int, default=0,
        help=
        'initial seed for each unique combination of environment and srl-model.'
    )
    parser.add_argument(
        '--srl-config-file', type=str, default="config/srl_models.yaml",
        help='Set the location of the SRL model path configuration.')

    # returns the parsed arguments, and the rest are assumed to be arguments for rl_baselines.train
    args, train_args = parser.parse_known_args()

    # Sanity check
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_iteration >= 1, "Error: --num-iteration cannot be less than 1"

    # Removing duplicates and sort
    srl_models = sorted(set(args.srl_model))
    envs = sorted(set(args.env))

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # NOTE(review): yaml.load without an explicit Loader is unsafe on
        # untrusted input and deprecated in PyYAML >= 5.1 — consider safe_load.
        all_models = yaml.load(f)

    # Checking definition and presence of all requested srl_models
    valid = True
    for env in envs:
        # validated the env definition
        if env not in all_models:
            printRed("Error: 'srl_models.yaml' missing definition for environment {}".format(env))
            valid = False
            continue  # skip to the next env, this one is not valid
        # checking log_folder for current env
        missing_log = "log_folder" not in all_models[env]
        if missing_log:
            printRed("Error: 'srl_models.yaml' missing definition for log_folder in environment {}".format(env))
            valid = False
        # validate each model for the current env definition
        for model in srl_models:
            if registered_srl[model][0] == SRLType.ENVIRONMENT:
                continue  # not an srl model, skip to the next model
            elif model not in all_models[env]:
                printRed("Error: 'srl_models.yaml' missing srl_model {} for environment {}".format(model, env))
                valid = False
            elif (not missing_log) and (
                    not os.path.exists(all_models[env]["log_folder"] + all_models[env][model])):
                # checking presence of srl_model path, if and only if log_folder exists
                printRed("Error: srl_model {} for environment {} was defined in ".format(model, env) +
                         "'srl_models.yaml', however the file {} it was targeting does not exist."
                         .format(all_models[env]["log_folder"] + all_models[env][model]))
                valid = False

    assert valid, "Errors occurred due to malformed 'srl_models.yaml', cannot continue."

    # check that all the SRL_models can be run on all the environments
    valid = True
    for env in envs:
        for model in srl_models:
            if registered_srl[model][1] is not None:
                found = False
                for compatible_class in registered_srl[model][1]:
                    if issubclass(compatible_class, registered_env[env][0]):
                        found = True
                        break
                if not found:
                    valid = False
                    printRed("Error: srl_model {}, is not compatible with the {} environment."
                             .format(model, env))
    assert valid, "Errors occurred due to an incompatible combination of srl_model and environment, cannot continue."

    # the seeds used in training the baseline.
    seeds = list(np.arange(args.num_iteration) + args.seed)

    if args.verbose:
        # None here means stdout of terminal for subprocess.call
        stdout = None
    else:
        # FIX: use subprocess.DEVNULL instead of open(os.devnull, 'w'),
        # which leaked a file handle for the whole run.
        stdout = subprocess.DEVNULL

    printGreen("\nRunning {} benchmarks {} times...".format(args.algo, args.num_iteration))
    print("\nSRL-Models:\t{}".format(srl_models))
    print("environments:\t{}".format(envs))
    print("verbose:\t{}".format(args.verbose))
    print("timesteps:\t{}".format(args.num_timesteps))

    for model in srl_models:
        for env in envs:
            for i in range(args.num_iteration):
                printGreen("\nIteration_num={} (seed: {}), Environment='{}', SRL-Model='{}'"
                           .format(i, seeds[i], env, model))
                # redefine the parsed args for rl_baselines.train
                loop_args = [
                    '--srl-model', model, '--seed', str(seeds[i]),
                    '--algo', args.algo, '--env', env,
                    '--num-timesteps', str(int(args.num_timesteps)),
                    '--srl-config-file', args.srl_config_file
                ]
                ok = subprocess.call(['python', '-m', 'rl_baselines.train'] + train_args + loop_args,
                                     stdout=stdout)
                if ok != 0:
                    # throw the error down to the terminal
                    raise ChildProcessError("An error occured, error code: {}".format(ok))
def main():
    """
    Train script entry point (short variant): parses the CLI options,
    validates them against the registered environments / SRL models, sets
    the module-level callback globals, then launches ``algo.train``.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='RL algo to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    # NOTE(review): uses an underscore (--episode_window) while every other
    # option here is dash-separated — confirm whether callers rely on this.
    parser.add_argument('--episode_window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir', default='/tmp/gym/', type=str,
                        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1, help='number of frames to stack (default: 1)')
    parser.add_argument('--action-repeat', type=int, default=1,
                        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097, help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False, help='disables visdom visualization')
    parser.add_argument('--shape-reward', action='store_true', default=False,
                        help='Shape the reward (reward = - distance) instead of a sparse reward')
    parser.add_argument('-c', '--continuous-actions', action='store_true', default=False)
    parser.add_argument('-joints', '--action-joints', action='store_true', default=False,
                        help='set actions to the joints of the arm directly, instead of inverse kinematics')
    parser.add_argument('-r', '--random-target', action='store_true', default=False,
                        help='Set the button to a random position')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')
    # free-form "name:value" pairs, parsed into a dict further below
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument('--latest', action='store_true', default=False,
                        help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument('--load-rl-model-path', type=str, default=None,
                        help="load the trained RL model, should be with the same algorithm type")
    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    env_kwargs = {}
    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # NOTE(review): yaml.load without an explicit Loader is unsafe on
        # untrusted files and deprecated in PyYAML >= 5.1 — consider
        # yaml.safe_load for this local config.
        all_models = yaml.load(f)
    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    assert 0 <= args.port < 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
        "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    if registered_srl[args.srl_model][1] is not None:
        found = False
        for compatible_class in registered_srl[args.srl_model][1]:
            if issubclass(compatible_class, registered_env[args.env][0]):
                found = True
                break
        assert found, "Error: srl_model {}, is not compatible with the {} environment.".format(args.srl_model,
                                                                                               args.env)
    # Publish the parsed options into the callback globals
    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save
    if args.no_vis:
        viz = False
    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo
    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL
    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(args.algo + " does not support discrete actions, please use the '--continuous-actions' " +
                         "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(args.algo + " does not support continuous actions, please remove the '--continuous-actions' " +
                         "(or '-c') flag.")
    env_kwargs["is_discrete"] = not args.continuous_actions
    printGreen("\nAgent = {} \n".format(args.algo))
    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # Allow up action
    # env_kwargs["force_down"] = False
    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    # Let the algorithm register its own extra CLI options, then re-parse
    parser = algo.customArguments(parser)
    args = parser.parse_args()
    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)
    env_class = registered_env[args.env][0]
    # env default kwargs
    default_env_kwargs = {k: v.default for k, v in
                          inspect.signature(env_class.__init__).parameters.items() if v is not None}
    globals_env_param = sys.modules[env_class.__module__].getGlobals()
    super_class = registered_env[args.env][1]
    # reccursive search through all the super classes of the asked environment, in order to get all the arguments.
    rec_super_class_lookup = {dict_class: dict_super_class
                              for _, (dict_class, dict_super_class, _, _) in registered_env.items()}
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
            ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {k: v.default for k, v in
                            inspect.signature(super_class.__init__).parameters.items() if v is not None}
        # subclass defaults take precedence over superclass defaults
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}
        globals_env_param = {**sys.modules[super_class.__module__].getGlobals(), **globals_env_param}
        super_class = rec_super_class_lookup[super_class]
    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(filterJSONSerializableObjects({**globals_env_param, **default_env_kwargs, **env_kwargs}))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using mutliprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband)
    hyperparams = {param.split(":")[0]: param.split(":")[1] for param in args.hyperparam}
    hyperparams = algo.parserHyperParam(hyperparams)
    if args.load_rl_model_path is not None:
        #use a small learning rate
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4
    # Train the agent
    if args.load_rl_model_path is not None:
        algo.setLoadPath(args.load_rl_model_path)
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters
    :param load_args: (Arguments)
    :return: (dict, str, str, str, dict)
    """
    # Recover the algorithm name from the log directory path; the first
    # registered algo whose name appears in the path wins.
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in load_args.log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    try:
        # If args contains episode information, this is for student_evaluation (distillation)
        if not load_args.episode == -1:
            load_path = "{}/{}_{}_model.pkl".format(load_args.log_dir, algo_name, load_args.episode,)
        else:
            load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)
    except AttributeError:
        # BUGFIX: was a bare `except:`, which also swallowed SystemExit /
        # KeyboardInterrupt and any unrelated error; only a missing
        # `episode` attribute on load_args should trigger the fallback.
        printYellow(
            "No episode of checkpoint specified, go for the default policy model: {}_model.pkl".format(algo_name))
        if load_args.log_dir[-3:] != 'pkl':
            load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)
        else:
            # log_dir already points at a .pkl file: use it directly and
            # rewrite log_dir to its containing directory
            load_path = load_args.log_dir
            load_args.log_dir = os.path.dirname(load_path) + '/'

    # Reload the environment globals and the training arguments saved at train time
    env_globals = json.load(open(load_args.log_dir + "env_globals.json", 'r'))
    train_args = json.load(open(load_args.log_dir + "args.json", 'r'))

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"]
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    # Restore the continual-learning task flags saved with the model
    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get("simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get("circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get("square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get("eight_continual_move", False)

        # If overriding the environment for specific Continual Learning tasks
        if sum([load_args.simple_continual, load_args.circular_continual, load_args.square_continual]) >= 1:
            env_kwargs["simple_continual_target"] = load_args.simple_continual
            env_kwargs["circular_continual_move"] = load_args.circular_continual
            env_kwargs["square_continual_move"] = load_args.square_continual
            env_kwargs["random_target"] = not (load_args.circular_continual or load_args.square_continual)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')
        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the srl saved model exists on the disk
            assert os.path.isfile(env_globals['srl_model_path']), \
                "{} does not exist".format(env_globals['srl_model_path'])
            srl_model_path = env_globals['srl_model_path']
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
def loadConfigAndSetup(log_dir):
    """
    Load training variables from a pre-trained model for evaluation.

    :param log_dir: (str) the path where the model is located,
        example: logs/sc2cc/OmnirobotEnv-v0/srl_combination/ppo2/19-05-07_11h32_39
    :return: (dict, str, class, str, dict)
        train_args, algo_name, algo_class (e.g. stable_baselines.PPO2),
        srl_model_path, env_kwargs
    """
    # Infer the algorithm from the log directory path
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for evaluation")

    # Context managers close the JSON file handles deterministically
    with open(log_dir + "env_globals.json", 'r') as f:
        env_globals = json.load(f)
    with open(log_dir + "args.json", 'r') as f:
        train_args = json.load(f)

    env_kwargs = {
        "renders": False,
        # TODO: since we don't use simple target, we should eliminate this choice?
        "shape_reward": False,
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"]
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get("simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get("circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get("square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get("eight_continual_move", False)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        # A learned state representation feeds an MLP policy instead of a CNN
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')
        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the srl saved model exists on the disk
            assert os.path.isfile(env_globals['srl_model_path']), \
                "{} does not exist".format(env_globals['srl_model_path'])
            srl_model_path = env_globals['srl_model_path']
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, algo_name, algo_class, srl_model_path, env_kwargs
def loadConfigAndSetup(load_args):
    """
    Get the training config and setup the parameters for replaying a policy.

    :param load_args: (Arguments) parsed CLI args; must carry `log_dir`,
        `render`, `shape_reward` and the continual-learning override flags
    :return: (dict, str, str, class, str, dict)
        train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs
    """
    # Infer the algorithm from the log directory path
    algo_name = ""
    for algo in list(registered_rl.keys()):
        if algo in load_args.log_dir:
            algo_name = algo
            break
    algo_class, algo_type, _ = registered_rl[algo_name]
    if algo_type == AlgoType.OTHER:
        raise ValueError(algo_name + " is not supported for replay")
    printGreen("\n" + algo_name + "\n")

    load_path = "{}/{}_model.pkl".format(load_args.log_dir, algo_name)

    # Context managers close the JSON file handles deterministically
    with open(load_args.log_dir + "env_globals.json", 'r') as f:
        env_globals = json.load(f)
    with open(load_args.log_dir + "args.json", 'r') as f:
        train_args = json.load(f)

    # Single lookup instead of repeated train_args.get("img_shape") calls.
    # The stored value is a string such as "(3,224,224)"; strip the
    # parentheses and parse the comma-separated ints.
    raw_img_shape = train_args.get("img_shape", None)
    if raw_img_shape is None:
        img_shape = None
    else:
        img_shape = tuple(map(int, raw_img_shape[1:-1].split(",")))

    env_kwargs = {
        "renders": load_args.render,
        "shape_reward": load_args.shape_reward,  # Reward sparse or shaped
        "action_joints": train_args["action_joints"],
        "is_discrete": not train_args["continuous_actions"],
        "random_target": train_args.get('random_target', False),
        "srl_model": train_args["srl_model"],
        "img_shape": img_shape
    }

    # load it, if it was defined
    if "action_repeat" in env_globals:
        env_kwargs["action_repeat"] = env_globals['action_repeat']

    # Remove up action
    if train_args["env"] == "Kuka2ButtonGymEnv-v0":
        env_kwargs["force_down"] = env_globals.get('force_down', True)
    else:
        env_kwargs["force_down"] = env_globals.get('force_down', False)

    if train_args["env"] == "OmnirobotEnv-v0":
        env_kwargs["simple_continual_target"] = env_globals.get("simple_continual_target", False)
        env_kwargs["circular_continual_move"] = env_globals.get("circular_continual_move", False)
        env_kwargs["square_continual_move"] = env_globals.get("square_continual_move", False)
        env_kwargs["eight_continual_move"] = env_globals.get("eight_continual_move", False)

    # If overriding the environment for specific Continual Learning tasks
    if sum([load_args.simple_continual, load_args.circular_continual, load_args.square_continual]) >= 1:
        env_kwargs["simple_continual_target"] = load_args.simple_continual
        env_kwargs["circular_continual_move"] = load_args.circular_continual
        env_kwargs["square_continual_move"] = load_args.square_continual
        env_kwargs["random_target"] = not (load_args.circular_continual or load_args.square_continual)

    srl_model_path = None
    if train_args["srl_model"] != "raw_pixels":
        # A learned state representation feeds an MLP policy instead of a CNN
        train_args["policy"] = "mlp"
        path = env_globals.get('srl_model_path')
        if path is not None:
            env_kwargs["use_srl"] = True
            # Check that the srl saved model exists on the disk
            assert os.path.isfile(env_globals['srl_model_path']), \
                "{} does not exist".format(env_globals['srl_model_path'])
            srl_model_path = env_globals['srl_model_path']
            env_kwargs["srl_model_path"] = srl_model_path

    return train_args, load_path, algo_name, algo_class, srl_model_path, env_kwargs