def main():
    """Run a hyperparameter search for an RL algorithm and log every evaluation.

    Arguments not recognised by this parser are forwarded to the training
    function built by ``makeRlTrainingFunction``, together with the
    seed/algo/env/SRL-model/log-dir options below. After the optimizer has run:

    - the best configuration is printed;
    - the whole search history (parameters, timesteps when every evaluation
      reported one, and reward) is written to a CSV file in ``logs/``.

    Raises:
        AssertionError: if the selected algorithm has no ``getOptParam`` or it
            returns ``None`` (i.e. no search space is defined).
        ValueError: if ``--optimizer`` names an optimizer with no branch here.
    """
    parser = argparse.ArgumentParser(description="Hyperparameter search for implemented RL models")
    parser.add_argument('--optimizer', default='hyperband', choices=['hyperband', 'hyperopt'], type=str,
                        help='The hyperparameter optimizer to choose from')
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()),
                        help='OpenAI baseline to use', type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    # NOTE(review): argparse only applies `type` to *string* defaults, so when the flag
    # is omitted this stays the float 1e6. That float reaches max_iter and the CSV file
    # name below. Kept as-is so existing output file names do not change.
    parser.add_argument('--num-timesteps', type=int, default=1e6,
                        help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Display baseline STDOUT')
    parser.add_argument('--max-eval', type=int, default=100, help='Number of evalutation to try for hyperopt')

    args, train_args = parser.parse_known_args()
    args.log_dir = "logs/_{}_search/".format(args.optimizer)

    # Options every training run must receive, appended to the forwarded arguments.
    train_args.extend(['--srl-model', args.srl_model, '--seed', str(args.seed), '--algo', args.algo,
                       '--env', args.env, '--log-dir', args.log_dir, '--no-vis'])

    # Verify the algorithm defines a search space and that it returns an expected value.
    # Both failure modes must raise the same descriptive error. The former
    # `except AttributeError or AssertionError:` caught only AttributeError, because
    # `A or B` evaluates to A.
    unsupported_msg = "Error: {} algo does not support hyperparameter search.".format(args.algo)
    try:
        opt_param = registered_rl[args.algo][0].getOptParam()
    except AttributeError as err:
        raise AssertionError(unsupported_msg) from err
    if opt_param is None:
        raise AssertionError(unsupported_msg)

    if args.optimizer == "hyperband":
        opt = Hyperband(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                        max_iter=args.num_timesteps // ITERATION_SCALE)
    elif args.optimizer == "hyperopt":
        opt = Hyperopt(opt_param, makeRlTrainingFunction(args, train_args), seed=args.seed,
                       num_eval=args.max_eval)
    else:
        raise ValueError("Error: optimizer {} was defined but not implemented, Halting.".format(args.optimizer))

    t_start = time.time()
    opt.run()

    # Each history entry is ((params, nb_iter), loss); the loss is the negated reward.
    all_params, loss = zip(*opt.history)
    idx = np.argmin(loss)
    opt_params, nb_iter = all_params[idx]
    reward = loss[idx]
    print('\ntime to run : {}s'.format(int(time.time() - t_start)))
    print('Total nb. evaluations : {}'.format(len(all_params)))
    if nb_iter is not None:
        print('Best nb. of iterations : {}'.format(int(nb_iter)))
    print('Best params : ')
    pprint.pprint(opt_params)
    print('Best reward : {:.3f}'.format(-reward))

    param_dict, timesteps = zip(*all_params)
    output = pd.DataFrame(list(param_dict))
    # Only log the timesteps column when every evaluation returned a value.
    if all(step is not None for step in timesteps):
        output["timesteps"] = np.maximum(MIN_ITERATION, np.array(timesteps) * ITERATION_SCALE).astype(int)
    output["reward"] = -np.array(loss)
    output.to_csv("logs/{}_{}_{}_{}_seed{}_numtimestep{}.csv"
                  .format(args.optimizer, args.algo, args.env, args.srl_model, args.seed, args.num_timesteps))
def main():
    """Parse the training CLI, configure the environment, and train the selected RL algorithm.

    Besides the usual training options, this parser also accepts:

    - the continual-learning scenario flags (``-sc``, ``-cc``, ``-sqc``, ``-ec``);
    - the policy-distillation options;
    - image-shape and relative-position overrides for the environment globals.

    Parsing happens in two passes. The first pass ignores unknown arguments.
    ``algo.customArguments`` then extends the parser, and the command line is
    parsed again strictly.

    The resolved arguments are written to ``LOG_DIR + "args.json"`` and the
    environment parameters are passed to ``saveEnvParams``. Finally
    ``algo.train`` is called with ``callback``.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='RL algo to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--episode-window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir', default='/tmp/gym/', type=str,
                        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1, help='number of frames to stack (default: 1)')
    parser.add_argument('--action-repeat', type=int, default=1,
                        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097, help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False, help='disables visdom visualization')
    parser.add_argument('--shape-reward', action='store_true', default=False,
                        help='Shape the reward (reward = - distance) instead of a sparse reward')
    parser.add_argument('-c', '--continuous-actions', action='store_true', default=False)
    parser.add_argument('-joints', '--action-joints', action='store_true', default=False,
                        help='set actions to the joints of the arm directly, instead of inverse kinematics')
    parser.add_argument('-r', '--random-target', action='store_true', default=False,
                        help='Set the button to a random position')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument('--latest', action='store_true', default=False,
                        help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument('--load-rl-model-path', type=str, default=None,
                        help="load the trained RL model, should be with the same algorithm type")
    parser.add_argument('-sc', '--simple-continual', action='store_true', default=False,
                        help='Simple red square target for task 1 of continual learning scenario. ' +
                             'The task is: robot should reach the target.')
    parser.add_argument('-cc', '--circular-continual', action='store_true', default=False,
                        help='Blue square target for task 2 of continual learning scenario. ' +
                             'The task is: robot should turn in circle around the target.')
    parser.add_argument('-sqc', '--square-continual', action='store_true', default=False,
                        help='Green square target for task 3 of continual learning scenario. ' +
                             'The task is: robot should turn in square around the target.')
    parser.add_argument('-ec', '--eight-continual', action='store_true', default=False,
                        help='Green square target for task 4 of continual learning scenario. ' +
                             'The task is: robot should do the eigth with the target as center of the shape.')
    parser.add_argument('--teacher-data-folder', type=str, default="",
                        help='Dataset folder of the teacher(s) policy(ies)', required=False)
    parser.add_argument('--epochs-distillation', type=int, default=30, metavar='N',
                        help='number of epochs to train for distillation(default: 30)')
    parser.add_argument('--distillation-training-set-size', type=int, default=-1,
                        help='Limit size (number of samples) of the training set (default: -1)')
    parser.add_argument('--perform-cross-evaluation-cc', action='store_true', default=False,
                        help='A cross evaluation from the latest stored model to all tasks')
    parser.add_argument('--eval-episode-window', type=int, default=400, metavar='N',
                        help='Episode window for saving each policy checkpoint for future distillation'
                             '(default: 400)')
    parser.add_argument('--new-lr', type=float, default=1.e-4,
                        help="New learning rate ratio to train a pretrained agent")
    parser.add_argument('--img-shape', type=str, default="(3,64,64)", help="Image shape of environment.")
    parser.add_argument("--gpu-num", help="Choose the number of GPU (CUDA_VISIBLE_DEVICES).", type=str,
                        default="1", choices=["0", "1", "2", "3", "4", "5", "6", "7", "8"])
    parser.add_argument("--srl-model-path", help="SRL model weights path", type=str, default=None)
    parser.add_argument("--relative-pos", action='store_true', default=False,
                        help="For 'ground_truth': use relative position or not.")
    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    env_kwargs = {}
    if args.img_shape is None:
        img_shape = None  # (3,224,224)
    else:
        # "(C,H,W)" -> (C, H, W): strip the surrounding brackets, then split on commas.
        img_shape = tuple(map(int, args.img_shape[1:-1].split(",")))
    env_kwargs['img_shape'] = img_shape

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # safe_load: this config is plain YAML. A bare yaml.load(f) raises on PyYAML >= 6
        # (Loader is required) and can build arbitrary Python objects on older versions.
        all_models = yaml.safe_load(f)

    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    # 65535 is a valid port, matching the inclusive range stated in the message.
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    compatible_classes = registered_srl[args.srl_model][1]
    if compatible_classes is not None:
        found = any(issubclass(compatible_class, registered_env[args.env][0])
                    for compatible_class in compatible_classes)
        assert found, "Error: srl_model {}, is not compatible with the {} environment.".format(
            args.srl_model, args.env)

    assert not (sum([args.simple_continual, args.circular_continual, args.square_continual,
                     args.eight_continual]) > 1 and args.env == "OmnirobotEnv-v0"), \
        "For continual SRL and RL, please provide only one scenario at the time and use OmnirobotEnv-v0 environment !"

    assert not (args.algo == "distillation" and (args.teacher_data_folder == '' or args.continuous_actions is True)), \
        "For performing policy distillation, make sure use specify a valid teacher dataset and discrete actions !"

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save
    # NOTE(review): CROSS_EVAL and EPISODE_WINDOW_DISTILLATION_WIN are locals of main()
    # (they are not in the `global` lists above), so a module-level callback cannot see
    # these values. Confirm whether they were meant to be declared global.
    CROSS_EVAL = args.perform_cross_evaluation_cc
    EPISODE_WINDOW_DISTILLATION_WIN = args.eval_episode_window
    print("EPISODE_WINDOW_DISTILLATION_WIN: ", EPISODE_WINDOW_DISTILLATION_WIN)

    if args.no_vis:
        viz = False

    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(args.algo + " does not support discrete actions, please use the '--continuous-actions' " +
                         "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(args.algo + " does not support continuous actions, please remove the '--continuous-actions' " +
                         "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # If in simple continual scenario, then the target should be initialized randomly.
    if args.simple_continual:
        env_kwargs["random_target"] = True

    # Allow up action
    # env_kwargs["force_down"] = False

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    # Second, strict parse once the algorithm has registered its own options.
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    env_class = registered_env[args.env][0]
    # env default kwargs
    # NOTE(review): `v` is an inspect.Parameter and is never None, so this filter keeps
    # every parameter (including `self` and parameters without a default).
    default_env_kwargs = {k: v.default
                          for k, v in inspect.signature(env_class.__init__).parameters.items()
                          if v is not None}

    globals_env_param = sys.modules[env_class.__module__].getGlobals()
    # HACK way to reset image shape !!
    if img_shape is not None:
        globals_env_param['RENDER_HEIGHT'] = img_shape[1]
        globals_env_param['RENDER_WIDTH'] = img_shape[2]
    globals_env_param['RELATIVE_POS'] = args.relative_pos

    super_class = registered_env[args.env][1]
    # recursive search through all the super classes of the asked environment, in order to get all the arguments.
    rec_super_class_lookup = {dict_class: dict_super_class
                              for _, (dict_class, dict_super_class, _, _) in registered_env.items()}
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {k: v.default
                            for k, v in inspect.signature(super_class.__init__).parameters.items()
                            if v is not None}
        # Subclass values take precedence over those of its parents.
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}

        globals_env_param = {**sys.modules[super_class.__module__].getGlobals(), **globals_env_param}

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(filterJSONSerializableObjects({**globals_env_param, **default_env_kwargs, **env_kwargs}))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using mutliprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband): each entry is "name:value".
    hyperparams = {param.split(":")[0]: param.split(":")[1] for param in args.hyperparam}
    hyperparams = algo.parserHyperParam(hyperparams)

    if args.load_rl_model_path is not None:
        # Fine-tuning a pretrained agent: scale the schedule by --new-lr (default 1e-4),
        # which previously was parsed but ignored in favour of a hard-coded 1e-4.
        new_lr = args.new_lr
        print("use a small learning rate: {:f}".format(new_lr))
        hyperparams["learning_rate"] = lambda f: f * new_lr
        algo.setLoadPath(args.load_rl_model_path)

    # Train the agent
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)
def main():
    """Benchmark one RL algorithm over every (SRL model, environment) combination.

    The configuration is validated in two steps before anything runs:

    1. Every requested SRL model must be defined in ``--srl-config-file`` for
       every requested environment, and its model file must exist.
    2. Every SRL model must be compatible with every environment.

    Then ``rl_baselines.train`` is launched in a subprocess ``--num-iteration``
    times per combination, with consecutive seeds starting at ``--seed``.
    Unrecognised arguments are forwarded verbatim to every training run.

    Raises:
        AssertionError: on invalid arguments or a malformed/incompatible config.
        ChildProcessError: if a training subprocess exits with a non-zero code.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI RL Baselines Benchmark",
        epilog='After the arguments are parsed, the rest are assumed to be arguments for' +
               ' rl_baselines.train')
    parser.add_argument('--algo', type=str, default='ppo2', help='OpenAI baseline to use',
                        choices=list(registered_rl.keys()))
    parser.add_argument('--env', type=str, nargs='+', default=["KukaButtonGymEnv-v0"], help='environment ID(s)',
                        choices=list(registered_env.keys()))
    parser.add_argument('--srl-model', type=str, nargs='+', default=["raw_pixels"], help='SRL model(s) to use',
                        choices=list(registered_srl.keys()))
    parser.add_argument('--num-timesteps', type=int, default=1e6, help='number of timesteps the baseline should run')
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Display baseline STDOUT')
    parser.add_argument('--num-iteration', type=int, default=15,
                        help='number of time each algorithm should be run for each unique combination of environment ' +
                             ' and srl-model.')
    parser.add_argument('--seed', type=int, default=0,
                        help='initial seed for each unique combination of environment and srl-model.')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')

    # returns the parsed arguments, and the rest are assumed to be arguments for rl_baselines.train
    args, train_args = parser.parse_known_args()

    # Sanity check
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_iteration >= 1, "Error: --num-iteration cannot be less than 1"

    # Remove duplicates and sort, so the run order is deterministic.
    srl_models = sorted(set(args.srl_model))
    envs = sorted(set(args.env))

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # safe_load: this config is plain YAML. A bare yaml.load(f) raises on PyYAML >= 6
        # (Loader is required) and can build arbitrary Python objects on older versions.
        all_models = yaml.safe_load(f)

    # Checking definition and presence of all requested srl_models
    valid = True
    for env in envs:
        # validate the env definition
        if env not in all_models:
            printRed("Error: 'srl_models.yaml' missing definition for environment {}".format(env))
            valid = False
            continue  # skip to the next env, this one is not valid

        # checking log_folder for current env
        missing_log = "log_folder" not in all_models[env]
        if missing_log:
            printRed("Error: 'srl_models.yaml' missing definition for log_folder in environment {}".format(env))
            valid = False

        # validate each model for the current env definition
        for model in srl_models:
            if registered_srl[model][0] == SRLType.ENVIRONMENT:
                continue  # not an srl model, skip to the next model
            elif model not in all_models[env]:
                printRed("Error: 'srl_models.yaml' missing srl_model {} for environment {}".format(model, env))
                valid = False
            elif (not missing_log) and (not os.path.exists(all_models[env]["log_folder"] + all_models[env][model])):
                # checking presence of srl_model path, if and only if log_folder exists
                printRed("Error: srl_model {} for environment {} was defined in ".format(model, env) +
                         "'srl_models.yaml', however the file {} it was tagetting does not exist."
                         .format(all_models[env]["log_folder"] + all_models[env][model]))
                valid = False

    assert valid, "Errors occured due to malformed 'srl_models.yaml', cannot continue."

    # check that all the SRL_models can be run on all the environments
    valid = True
    for env in envs:
        for model in srl_models:
            compatible_classes = registered_srl[model][1]
            if compatible_classes is None:
                continue  # no restriction declared for this model
            if not any(issubclass(compatible_class, registered_env[env][0])
                       for compatible_class in compatible_classes):
                valid = False
                printRed("Error: srl_model {}, is not compatible with the {} environment.".format(model, env))
    assert valid, "Errors occured due to an incompatible combination of srl_model and environment, cannot continue."

    # the seeds used in training the baseline.
    seeds = list(np.arange(args.num_iteration) + args.seed)

    # None means the subprocess inherits the terminal's stdout. DEVNULL replaces a
    # never-closed open(os.devnull) handle.
    stdout = None if args.verbose else subprocess.DEVNULL

    printGreen("\nRunning {} benchmarks {} times...".format(args.algo, args.num_iteration))
    print("\nSRL-Models:\t{}".format(srl_models))
    print("environments:\t{}".format(envs))
    print("verbose:\t{}".format(args.verbose))
    print("timesteps:\t{}".format(args.num_timesteps))
    for model in srl_models:
        for env in envs:
            for i in range(args.num_iteration):
                printGreen("\nIteration_num={} (seed: {}), Environment='{}', SRL-Model='{}'"
                           .format(i, seeds[i], env, model))

                # redefine the parsed args for rl_baselines.train
                loop_args = ['--srl-model', model, '--seed', str(seeds[i]), '--algo', args.algo, '--env', env,
                             '--num-timesteps', str(int(args.num_timesteps)),
                             '--srl-config-file', args.srl_config_file]

                ok = subprocess.call(['python', '-m', 'rl_baselines.train'] + train_args + loop_args, stdout=stdout)

                if ok != 0:
                    # throw the error down to the terminal
                    raise ChildProcessError("An error occured, error code: {}".format(ok))
def main():
    """Evaluate policy distillation from two teachers at several checkpoints of the second teacher.

    For each selected checkpoint (episode) of the second ("learning") teacher:

    1. Generate on-policy data from that checkpoint.
    2. Unless skipped, merge it with data generated once from the first
       ("optimal") teacher.
    3. Copy the resulting dataset into ``srl_zoo/data/`` and train a student by
       distillation on it.
    4. Evaluate the newest student with ``replay.enjoy_baselines`` on the
       ``-sc`` and ``-cc`` tasks, once for each of ``--num-iteration`` seeds.

    The rewards per checkpoint are written as JSON into ``--log-dir-student``.
    Passing the literal string ``"None"`` as ``--log-dir-teacher-one`` skips
    the first teacher, so the student is distilled from the learning teacher's
    data alone.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for distillation from two teacher policies")
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--env', type=str, help='environment ID', default='OmnirobotEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--episode_window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir-teacher-one', default='/tmp/gym/', type=str,
                        help='directory to load an optmimal agent for task 1 (default: /tmp/gym)')
    parser.add_argument('--log-dir-teacher-two', default='/tmp/gym/', type=str,
                        help='directory to load an optmimal agent for task 2 (default: /tmp/gym)')
    parser.add_argument('--log-dir-student', default='/tmp/gym/', type=str,
                        help='directory to save the student agent logs and model (default: /tmp/gym)')
    parser.add_argument('--srl-config-file-one', type=str, default="config/srl_models_one.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--srl-config-file-two', type=str, default="config/srl_models_two.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--epochs-distillation', type=int, default=30, metavar='N',
                        help='number of epochs to train for distillation(default: 30)')
    parser.add_argument('--distillation-training-set-size', type=int, default=-1,
                        help='Limit size (number of samples) of the training set (default: -1)')
    parser.add_argument('--eval-tasks', type=str, nargs='+', default=['cc', 'sqc', 'sc'],
                        help='A cross evaluation from the latest stored model to all tasks')
    parser.add_argument('--continual-learning-labels', type=str, nargs=2, metavar=('label_1', 'label_2'),
                        default=argparse.SUPPRESS, help='Labels for the continual learning RL distillation task.')
    parser.add_argument('--student-srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use for the student RL policy')
    parser.add_argument('--epochs-teacher-datasets', type=int, default=30, metavar='N',
                        help='number of epochs for generating both RL teacher datasets (default: 30)')
    parser.add_argument('--num-iteration', type=int, default=1,
                        help='number of time each algorithm should be run the eval (N seeds).')
    parser.add_argument('--eval-episode-window', type=int, default=400, metavar='N',
                        help='Episode window for saving each policy checkpoint for future distillation'
                             '(default: 400)')

    args, unknown = parser.parse_known_args()

    # The labels default to SUPPRESS but are used unconditionally below (dataset folder
    # names, task ids). Fail with a usage error instead of a later AttributeError.
    if 'continual_learning_labels' not in args:
        parser.error("--continual-learning-labels is required (one label per teacher)")
    assert args.continual_learning_labels[0] in CONTINUAL_LEARNING_LABELS and args.continual_learning_labels[1] \
        in CONTINUAL_LEARNING_LABELS, "Please specify a valid Continual learning label to each dataset to be " \
                                      "used for RL distillation !"
    print(args.continual_learning_labels)

    assert os.path.exists(args.srl_config_file_one), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file_one)
    assert os.path.exists(args.srl_config_file_two), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file_two)

    # "None" (as a string) disables the first teacher entirely.
    use_teacher_one = args.log_dir_teacher_one != "None"
    if use_teacher_one:
        assert os.path.exists(args.log_dir_teacher_one), \
            "Error: cannot load \"--log-dir-teacher-one {}\", directory not found!".format(args.log_dir_teacher_one)
    # The second teacher is always used (see allPolicy below), so always check it.
    assert os.path.exists(args.log_dir_teacher_two), \
        "Error: cannot load \"--log-dir-teacher-two {}\", directory not found!".format(args.log_dir_teacher_two)

    teacher_pro = args.log_dir_teacher_one
    teacher_learn = args.log_dir_teacher_two

    # Output dataset folders (under data/) for each teacher's generated on-policy data
    teacher_pro_data = args.continual_learning_labels[0] + '/'
    teacher_learn_data = args.continual_learning_labels[1] + '/'
    merge_path = "data/on_policy_merged"

    print(teacher_pro_data, teacher_learn_data)
    episodes, policy_path = allPolicy(teacher_learn)

    rewards_at_episode = {}
    # Every 200 episodes up to and including 2000, then every 1000 episodes.
    # Episode 2000 used to fall between `< 2000` and `> 2000` and was skipped.
    episodes_to_test = [e for e in episodes
                        if (int(e) <= 2000 and int(e) % 200 == 0) or (int(e) > 2000 and int(e) % 1000 == 0)]

    # generate data from Professional teacher
    printYellow("\nGenerating on policy for optimal teacher: " + args.continual_learning_labels[0])

    if use_teacher_one:
        OnPolicyDatasetGenerator(teacher_pro, args.continual_learning_labels[0] + '_copy/',
                                 task_id=args.continual_learning_labels[0], num_eps=args.epochs_teacher_datasets,
                                 episode=-1, env_name=args.env)
    print("Eval on eps list: ", episodes_to_test)
    for eps in episodes_to_test:
        student_path = args.log_dir_student
        printBlue("\n\nEvaluation at episode " + str(eps))

        if use_teacher_one:
            # Use a copy of the optimal teacher
            ok = subprocess.call(['cp', '-r', 'data/' + args.continual_learning_labels[0] + '_copy/',
                                  'data/' + teacher_pro_data, '-f'])
            assert ok == 0
            time.sleep(2)

        # Generate data from learning teacher
        printYellow("\nGenerating on-policy data from the optimal teacher: " + args.continual_learning_labels[1])
        OnPolicyDatasetGenerator(teacher_learn, teacher_learn_data, task_id=args.continual_learning_labels[1],
                                 episode=eps, num_eps=args.epochs_teacher_datasets, env_name=args.env)

        if not use_teacher_one:
            merge_path = 'data/' + teacher_learn_data
            ok = subprocess.call(['cp', '-r', merge_path, 'srl_zoo/data/', '-f'])
        else:
            # merge the data
            mergeData('data/' + teacher_pro_data, 'data/' + teacher_learn_data, merge_path, force=True)
            ok = subprocess.call(['cp', '-r', 'data/on_policy_merged/', 'srl_zoo/data/', '-f'])
        assert ok == 0
        time.sleep(2)

        # Train a policy with distillation on the merged teacher's datasets
        trainStudent('srl_zoo/' + merge_path, args.continual_learning_labels[1], yaml_file=args.srl_config_file_one,
                     log_dir=args.log_dir_student, srl_model=args.student_srl_model, env_name=args.env,
                     training_size=args.distillation_training_set_size, epochs=args.epochs_distillation)

        student_path += args.env + '/' + args.student_srl_model + "/distillation/"
        # The most recently modified sub-directory is the student that was just trained.
        latest_student_path = max([student_path + "/" + d for d in os.listdir(student_path)
                                   if os.path.isdir(student_path + "/" + d)], key=os.path.getmtime) + '/'

        rewards = {}
        printRed("\nSaving the student at path: " + latest_student_path)
        for task_label in ["-sc", "-cc"]:
            rewards[task_label] = []
            for seed_i in range(args.num_iteration):
                printYellow("\nEvaluating student on task: " + task_label + " for seed: " + str(seed_i))
                command_line_enjoy_student = ['python', '-m', 'replay.enjoy_baselines', '--num-timesteps', '251',
                                              '--log-dir', latest_student_path, task_label, "--seed", str(seed_i)]
                output = subprocess.check_output(command_line_enjoy_student).decode('utf-8')

                # The reward is the text between "Mean reward: " and the next "\npybullet".
                str_before = "Mean reward: "
                str_after = "\npybullet"
                start = output.find(str_before)
                if start == -1:
                    raise RuntimeError("Could not find '{}' in the output of replay.enjoy_baselines"
                                       .format(str_before.strip()))
                idx_before = start + len(str_before)
                # Search only after the reward, so an earlier banner line cannot cut the slice short.
                idx_after = output.find(str_after, idx_before)
                if idx_after == -1:
                    idx_after = len(output)
                seed_reward = float(output[idx_before:idx_after])
                rewards[task_label].append(seed_reward)
        print("rewards at eps ", eps, ": ", rewards)
        rewards_at_episode[eps] = rewards
    print("All rewards: ", rewards_at_episode)
    json_dict = json.dumps(rewards_at_episode)
    json_dict_name = \
        args.log_dir_student + "/reward_at_episode_" + datetime.datetime.now().strftime("%y-%m-%d_%Hh%M_%S") + '.json'
    with open(json_dict_name, "w") as f:
        f.write(json_dict)
    printRed("\nSaving the evalation at path: " + json_dict_name)
def main():
    """Parse the training CLI, configure the environment, and train the selected RL algorithm.

    Parsing happens in two passes. The first pass ignores unknown arguments.
    ``algo.customArguments`` then extends the parser, and the command line is
    parsed again strictly.

    The resolved arguments are written to ``LOG_DIR + "args.json"`` and the
    environment parameters are passed to ``saveEnvParams``. Finally
    ``algo.train`` is called with ``callback``.
    Several module-level globals read by the callback are set along the way.
    """
    # Global variables for callback
    global ENV_NAME, ALGO, ALGO_NAME, LOG_INTERVAL, VISDOM_PORT, viz
    global SAVE_INTERVAL, EPISODE_WINDOW, MIN_EPISODES_BEFORE_SAVE
    parser = argparse.ArgumentParser(description="Train script for RL algorithms")
    parser.add_argument('--algo', default='ppo2', choices=list(registered_rl.keys()), help='RL algo to use',
                        type=str)
    parser.add_argument('--env', type=str, help='environment ID', default='KukaButtonGymEnv-v0',
                        choices=list(registered_env.keys()))
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--episode_window', type=int, default=40,
                        help='Episode window for moving average plot (default: 40)')
    parser.add_argument('--log-dir', default='/tmp/gym/', type=str,
                        help='directory to save agent logs and model (default: /tmp/gym)')
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    parser.add_argument('--srl-model', type=str, default='raw_pixels', choices=list(registered_srl.keys()),
                        help='SRL model to use')
    parser.add_argument('--num-stack', type=int, default=1, help='number of frames to stack (default: 1)')
    parser.add_argument('--action-repeat', type=int, default=1,
                        help='number of times an action will be repeated (default: 1)')
    parser.add_argument('--port', type=int, default=8097, help='visdom server port (default: 8097)')
    parser.add_argument('--no-vis', action='store_true', default=False, help='disables visdom visualization')
    parser.add_argument('--shape-reward', action='store_true', default=False,
                        help='Shape the reward (reward = - distance) instead of a sparse reward')
    parser.add_argument('-c', '--continuous-actions', action='store_true', default=False)
    parser.add_argument('-joints', '--action-joints', action='store_true', default=False,
                        help='set actions to the joints of the arm directly, instead of inverse kinematics')
    parser.add_argument('-r', '--random-target', action='store_true', default=False,
                        help='Set the button to a random position')
    parser.add_argument('--srl-config-file', type=str, default="config/srl_models.yaml",
                        help='Set the location of the SRL model path configuration.')
    parser.add_argument('--hyperparam', type=str, nargs='+', default=[])
    parser.add_argument('--min-episodes-save', type=int, default=100,
                        help="Min number of episodes before saving best model")
    parser.add_argument('--latest', action='store_true', default=False,
                        help='load the latest learned model (location:srl_zoo/logs/DatasetName/)')
    parser.add_argument('--load-rl-model-path', type=str, default=None,
                        help="load the trained RL model, should be with the same algorithm type")

    # Ignore unknown args for now
    args, unknown = parser.parse_known_args()
    env_kwargs = {}

    # LOAD SRL models list
    assert os.path.exists(args.srl_config_file), \
        "Error: cannot load \"--srl-config-file {}\", file not found!".format(args.srl_config_file)
    with open(args.srl_config_file, 'rb') as f:
        # safe_load: this config is plain YAML. A bare yaml.load(f) raises on PyYAML >= 6
        # (Loader is required) and can build arbitrary Python objects on older versions.
        all_models = yaml.safe_load(f)

    # Sanity check
    assert args.episode_window >= 1, "Error: --episode_window cannot be less than 1"
    assert args.num_timesteps >= 1, "Error: --num-timesteps cannot be less than 1"
    assert args.num_stack >= 1, "Error: --num-stack cannot be less than 1"
    assert args.action_repeat >= 1, "Error: --action-repeat cannot be less than 1"
    # 65535 is a valid port, matching the inclusive range stated in the message.
    assert 0 <= args.port <= 65535, "Error: invalid visdom port number {}, ".format(args.port) + \
                                    "port number must be an unsigned 16bit number [0,65535]."
    assert registered_srl[args.srl_model][0] == SRLType.ENVIRONMENT or args.env in all_models, \
        "Error: the environment {} has no srl_model defined in 'srl_models.yaml'. Cannot continue.".format(args.env)
    # check that all the SRL_model can be run on the environment
    compatible_classes = registered_srl[args.srl_model][1]
    if compatible_classes is not None:
        found = any(issubclass(compatible_class, registered_env[args.env][0])
                    for compatible_class in compatible_classes)
        assert found, "Error: srl_model {}, is not compatible with the {} environment.".format(
            args.srl_model, args.env)

    ENV_NAME = args.env
    ALGO_NAME = args.algo
    VISDOM_PORT = args.port
    EPISODE_WINDOW = args.episode_window
    MIN_EPISODES_BEFORE_SAVE = args.min_episodes_save

    if args.no_vis:
        viz = False

    algo_class, algo_type, action_type = registered_rl[args.algo]
    algo = algo_class()
    ALGO = algo

    # if callback frequency needs to be changed
    LOG_INTERVAL = algo.LOG_INTERVAL
    SAVE_INTERVAL = algo.SAVE_INTERVAL

    if not args.continuous_actions and ActionType.DISCRETE not in action_type:
        raise ValueError(args.algo + " does not support discrete actions, please use the '--continuous-actions' " +
                         "(or '-c') flag.")
    if args.continuous_actions and ActionType.CONTINUOUS not in action_type:
        raise ValueError(args.algo + " does not support continuous actions, please remove the '--continuous-actions' " +
                         "(or '-c') flag.")

    env_kwargs["is_discrete"] = not args.continuous_actions

    printGreen("\nAgent = {} \n".format(args.algo))

    env_kwargs["action_repeat"] = args.action_repeat
    # Random init position for button
    env_kwargs["random_target"] = args.random_target
    # Allow up action
    # env_kwargs["force_down"] = False

    # allow multi-view
    env_kwargs['multi_view'] = args.srl_model == "multi_view_srl"
    # Second, strict parse once the algorithm has registered its own options.
    parser = algo.customArguments(parser)
    args = parser.parse_args()

    args, env_kwargs = configureEnvAndLogFolder(args, env_kwargs, all_models)
    args_dict = filterJSONSerializableObjects(vars(args))
    # Save args
    with open(LOG_DIR + "args.json", "w") as f:
        json.dump(args_dict, f)

    env_class = registered_env[args.env][0]
    # env default kwargs
    # NOTE(review): `v` is an inspect.Parameter and is never None, so this filter keeps
    # every parameter (including `self` and parameters without a default).
    default_env_kwargs = {k: v.default
                          for k, v in inspect.signature(env_class.__init__).parameters.items()
                          if v is not None}

    globals_env_param = sys.modules[env_class.__module__].getGlobals()

    super_class = registered_env[args.env][1]
    # recursive search through all the super classes of the asked environment, in order to get all the arguments.
    rec_super_class_lookup = {dict_class: dict_super_class
                              for _, (dict_class, dict_super_class, _, _) in registered_env.items()}
    while super_class != SRLGymEnv:
        assert super_class in rec_super_class_lookup, "Error: could not find super class of {}".format(super_class) + \
                                                      ", are you sure \"registered_env\" is correctly defined?"
        super_env_kwargs = {k: v.default
                            for k, v in inspect.signature(super_class.__init__).parameters.items()
                            if v is not None}
        # Subclass values take precedence over those of its parents.
        default_env_kwargs = {**super_env_kwargs, **default_env_kwargs}

        globals_env_param = {**sys.modules[super_class.__module__].getGlobals(), **globals_env_param}

        super_class = rec_super_class_lookup[super_class]

    # Print Variables
    printYellow("Arguments:")
    pprint(args_dict)
    printYellow("Env Globals:")
    pprint(filterJSONSerializableObjects({**globals_env_param, **default_env_kwargs, **env_kwargs}))
    # Save env params
    saveEnvParams(globals_env_param, {**default_env_kwargs, **env_kwargs})
    # Seed tensorflow, python and numpy random generator
    set_global_seeds(args.seed)
    # Augment the number of timesteps (when using mutliprocessing this number is not reached)
    args.num_timesteps = int(1.1 * args.num_timesteps)
    # Get the hyperparameter, if given (Hyperband): each entry is "name:value".
    hyperparams = {param.split(":")[0]: param.split(":")[1] for param in args.hyperparam}
    hyperparams = algo.parserHyperParam(hyperparams)

    if args.load_rl_model_path is not None:
        # Fine-tuning a pretrained agent: use a small learning rate.
        print("use a small learning rate: {:f}".format(1.0e-4))
        hyperparams["learning_rate"] = lambda f: f * 1.0e-4
        algo.setLoadPath(args.load_rl_model_path)

    # Train the agent
    algo.train(args, callback, env_kwargs=env_kwargs, train_kwargs=hyperparams)