def main_sl():
    """Run the PnnMethod on a supervised-learning (SL) setting.

    The setting, the experiment ``Config`` and the ``PnnMethod`` are all
    configured from the command line. The results returned by the setting's
    ``apply`` are printed (via ``summary()``) and then returned.
    """
    # NOTE: inside a function, ``__doc__`` resolves to the *module* docstring,
    # not this function's docstring.
    arg_parser = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)

    # --- Command-line options for the setting and the run configuration. ---
    # TODO: PNN is written for the DomainIncrementalSetting, where the action
    # space is the same for every task; a TaskIncrementalSLSetting is used here.
    arg_parser.add_arguments(TaskIncrementalSLSetting, dest="setting")
    Config.add_argparse_args(arg_parser, dest="config")

    # --- Command-line options for the method. ---
    PnnMethod.add_argparse_args(arg_parser, dest="method")

    parsed = arg_parser.parse_args()

    sl_setting: TaskIncrementalSLSetting = TaskIncrementalSLSetting.from_argparse_args(
        parsed,
        dest="setting",
    )
    run_config: Config = Config.from_argparse_args(parsed, dest="config")
    pnn: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")

    # The method is given the same Config object that is passed to apply().
    pnn.config = run_config

    outcome = sl_setting.apply(pnn, config=run_config)
    print(outcome.summary())
    return outcome
def main_rl():
    """Run the PnnMethod on a fixed two-task CartPole reinforcement-learning setting.

    Only the experiment ``Config`` and the ``PnnMethod`` options are read from
    the command line; the setting itself is hard-coded below. The results'
    summary and objective are printed, and the results object is returned.
    """
    # NOTE: inside a function, ``__doc__`` resolves to the *module* docstring.
    cli = ArgumentParser(description=__doc__, add_dest_to_option_strings=False)
    Config.add_argparse_args(cli, dest="config")
    PnnMethod.add_argparse_args(cli, dest="method")

    # Two tasks that differ only in pole length. The keys are presumably the
    # training steps at which each task starts (TODO: confirm).
    cartpole_schedule = {
        0: {"gravity": 10, "length": 0.3},
        1000: {"gravity": 10, "length": 0.5},
    }
    # Only tested with observe_state_directly=True. With False it runs, but
    # convergence has not been verified.
    rl_setting = TaskIncrementalRLSetting(
        dataset="cartpole",
        observe_state_directly=True,
        nb_tasks=2,
        train_task_schedule=cartpole_schedule,
    )

    parsed = cli.parse_args()
    run_config: Config = Config.from_argparse_args(parsed, dest="config")
    pnn: PnnMethod = PnnMethod.from_argparse_args(parsed, dest="method")
    # The method is given the same Config object that is passed to apply().
    pnn.config = run_config

    outcome = rl_setting.apply(pnn, config=run_config)
    print(outcome.summary())
    print(f"objective: {outcome.objective}")
    return outcome