def run(root, path_script):
    # Resolve experiment config, trainer factory and model class from the experiment root.
    experiment = Experiment(root=root)
    factory = experiment.get_factory()
    params = experiment.get_params()
    # callbacks_factory = experiment.get_factory(source_function='get_ens_callbacks')
    params['factory_args']['learner_args']['dump_path'] = root
    # params['factory_args']['dump_path'] = root
    Model = experiment.get_model_class()
    experiment.document_script(path_script, overwrite=params['overwrite'])

    # Size the network to the environment's observation / action spaces.
    env = MultiInstanceGym(**params['env_args'])
    params['factory_args']['model_args']['in_nodes'] = env.observation_space.shape[0]
    params['factory_args']['model_args']['out_nodes'] = env.action_space.n
    params['factory_args']['env'] = env

    dqn_player = get_player(key=params['player_type'])
    selection_strategy = get_selection_strategy(
        params['selection_strategy'], params.get('selection_args', {}))

    with with_experiment(experiment=experiment, overwrite=params['overwrite']):
        memory = get_memory(params['memory_type'], params['memory_args'])
        params['factory_args']['learner_args']['memory'] = memory

        # Train the ensemble; evaluate greedily (deterministic) and with the
        # configured stochastic strategy every 10 epochs.
        learner = DQNEnsemble(
            model_class=Model,
            trainer_factory=factory,
            memory=memory,
            env=env,
            player=dqn_player,
            selection_strategy=selection_strategy,
            trainer_args=params['factory_args'],
            n_model=params['n_learner'],
            dump_path=root,
            callbacks=[
                rcb.EpisodeUpdater(**params.get('memory_update', {}), frequency=5),
                rcb.UncertaintyUpdater(hat=hat.EntropyHat()),
                cb.Checkpointer(frequency=10),
                rcb.EnvironmentEvaluator(
                    env=env,
                    n_evaluations=10,
                    action_selector=sp.GreedyValueSelection(
                        post_pipeline=[hat.EnsembleHat()]),
                    metrics={'det_val_reward_mean': np.mean,
                             'det_val_reward_std': np.std},
                    frequency=10,
                    epoch_name='det_val_epoch'),
                rcb.EnvironmentEvaluator(
                    env=env,
                    n_evaluations=10,
                    action_selector=get_selection_strategy(
                        params['eval_selection_strategy'],
                        params.get('selection_args', {})),
                    metrics={'prob_val_reward_mean': np.mean,
                             'prob_val_reward_std': np.std},
                    frequency=10,
                    epoch_name='prob_val_epoch'),
                rcb.EnsembleRewardPlotter(
                    frequency=10,
                    metrics={'det_val_reward_mean': 'det_val_epoch',
                             'prob_val_reward_mean': 'prob_val_epoch'}),
            ])
        learner.fit(**params['fit'])
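
# Hedged sketch (not part of the original scripts) of the params dict that
# Experiment.get_params() is expected to return for the run() above. The keys
# mirror the lookups in the function; all values are illustrative placeholders.
EXAMPLE_PARAMS = {
    'overwrite': False,
    'player_type': 'dqn',
    'selection_strategy': 'epsilon_greedy',
    'eval_selection_strategy': 'q_selection',
    'selection_args': {},
    'memory_type': 'replay',
    'memory_args': {'memory_size': 100_000},
    'memory_update': {},
    'n_learner': 5,
    'fit': {'epochs': 100},
    'env_args': {'env_name': 'CartPole-v1'},
    'factory_args': {
        'learner_args': {},   # dump_path and memory are filled in by run()
        'model_args': {},     # in_nodes / out_nodes are filled in by run()
    },
}
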
def run(root, path_script):
    print(root, path_script)
    # Resolve experiment config, trainer factory and model class from the experiment root.
    experiment = Experiment(root=root)
    factory = experiment.get_factory()
    params = experiment.get_params()
    params['factory_args']['learner_args']['dump_path'] = root
    Model = experiment.get_model_class()
    experiment.document_script(path_script, overwrite=params['overwrite'])

    # Size the network to the environment's observation / action spaces.
    # env = MultiInstanceGym(**params['factory_args']['env_args'])
    env = TorchGym(params['factory_args']['env_args']['env_name'])
    params['factory_args']['model_args']['in_nodes'] = env.observation_space.shape[0]
    params['factory_args']['model_args']['out_nodes'] = env.action_space.n

    dqn_player = DQNPlayer()
    # selection_strategy = sp.AdaptiveQActionSelectionEntropy(
    #     warm_up=0, post_pipeline=[EnsembleHatStd()])
    # selection_strategy = sp.QActionSelection(post_pipeline=[EnsembleHat()])
    selection_strategy = sp.EpsilonGreedyActionSelection(
        action_space=[0, 1, 2, 3], post_pipeline=[EnsembleHat()])

    with with_experiment(experiment=experiment, overwrite=params['overwrite']):
        memory = Memory(**params['factory_args']['memory_args'])
        params['factory_args']['learner_args']['memory'] = memory

        # Train the ensemble; evaluate greedily (deterministic) and with the
        # epsilon-greedy strategy after every epoch.
        learner = DQNEnsemble(
            model_class=Model,
            trainer_factory=factory,
            memory=memory,
            env=env,
            player=dqn_player,
            selection_strategy=selection_strategy,
            trainer_args=params['factory_args'],
            n_model=params['n_learner'],
            callbacks=[
                rcb.EpisodeUpdater(**params.get('memory_update', {})),
                cb.Checkpointer(frequency=1),
                # rcb.UncertaintyUpdater(),
                rcb.EnvironmentEvaluator(
                    env=TorchGym(params['factory_args']['env_args']['env_name']),
                    n_evaluations=10,
                    action_selector=sp.GreedyValueSelection(
                        post_pipeline=[EnsembleHat()]),
                    metrics={'det_val_reward_mean': np.mean,
                             'det_val_reward_std': np.std},
                    frequency=1,
                    epoch_name='det_val_epoch'),
                rcb.EnvironmentEvaluator(
                    env=TorchGym(params['factory_args']['env_args']['env_name']),
                    n_evaluations=10,
                    action_selector=selection_strategy,
                    metrics={'prob_val_reward_mean': np.mean,
                             'prob_val_reward_std': np.std},
                    frequency=1,
                    epoch_name='prob_val_epoch'),
                rcb.EnsembleRewardPlotter(
                    metrics={'det_val_reward_mean': 'det_val_epoch',
                             'prob_val_reward_mean': 'prob_val_epoch'}),
            ])
        # learner.load_checkpoint(path=f'{root}/checkpoint', tag='checkpoint')
        learner.fit(**params['fit'])
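
# Hedged usage sketch (not part of the original scripts): assumes the module is
# executed directly with the experiment root as the first CLI argument, so that
# document_script() can archive the running file itself via __file__.
if __name__ == '__main__':
    import sys

    run(root=sys.argv[1], path_script=__file__)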