def train(env_name):
    ModelCatalog.register_custom_model("masked_actions_model", MaskedActionsCNN)
    model_config = {
        "custom_model": "masked_actions_model",
        "conv_filters": [[16, [2, 2], 1], [32, [2, 2], 1]],
        "conv_activation": "elu",
        "fcnet_hiddens": [128],
        "fcnet_activation": "elu",
    }
    tune_config = {
        "num_workers": 24,
        "num_gpus": 1,
        "batch_mode": "complete_episodes",
        "model": model_config,
        "env": env_name,
        "lr": 0.001,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping,
        },
        "framework": "tf",
    }
    trainer = DQNTrainer(env=env_name, config=tune_config)
    for i in range(1000):
        print("== Iteration {} ==".format(i))
        results = trainer.train()
        print(pretty_print(results))
        checkpoint = trainer.save()
        print("\nCheckpoint saved at {}\n".format(checkpoint))

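# The snippet above references module-level `policies` and `policy_mapping`
# that are not shown here. A minimal sketch of what they might look like for
# a two-player masked-action game follows; the space shapes and agent ids are
# assumptions, not taken from the original project.
from gym.spaces import Box, Dict, Discrete
import numpy as np

obs_space = Dict({
    "board": Box(low=0, high=1, shape=(8, 8, 1), dtype=np.float32),
    "action_mask": Box(low=0, high=1, shape=(10,), dtype=np.float32),
})
act_space = Discrete(10)

# None lets RLlib use the trainer's default policy class (DQN here).
policies = {
    "player_1": (None, obs_space, act_space, {}),
    "player_2": (None, obs_space, act_space, {}),
}

def policy_mapping(agent_id):
    # each env agent id maps directly onto the policy of the same name
    return agent_id
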
def run(self, checkpoint=None, iters=None, dry_run=False):
    if self._cli_args is None:
        self.load_cli_args()
    if not self._cli_args['debug']:
        try:
            ray.init(address='auto')
        except ConnectionError:
            ray.init()
            print("Running in single node!")
        else:
            print("Running in cluster!")
    else:
        ray.init(local_mode=True)

    self._trainer = ppo.PPOTrainer(config=self._all_config)
    if checkpoint is not None:
        self._trainer.restore(checkpoint)

    if not dry_run:
        for _ in range(self._iterations if iters is None else int(iters)):
            res = self._trainer.train()
            if (_ + 1) % 10 == 0:
                print(pretty_print(res))
            if (_ + 1) % 100 == 0:
                self._checkpoint_path = self._trainer.save()
                print(f"Model saved at {self._checkpoint_path}")

def trainDqn(numIter):
    """ train """
    ray.shutdown()
    ray.init()
    config = createConfig()
    trainer = dqn.DQNTrainer(config=config, env=HiLoPricingEnv)
    for i in range(numIter):
        print("\n**** next iteration " + str(i))
        HiLoPricingEnv.count = 0
        result = trainer.train()
        print(pretty_print(result))
        print("env reset count " + str(HiLoPricingEnv.count))

    policy = trainer.get_policy()
    weights = policy.get_weights()
    #print("policy weights")
    #print(weights)
    model = policy.model
    #summary = model.base_model.summary()
    #print("model summary")
    #print(weights)
    return trainer

def game_train():
    # config = a3c.a2c.A2C_DEFAULT_CONFIG.copy()
    config = ppo.appo.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 1
    config["num_workers"] = 12
    # config["lambda"] = 1.0
    config["model"] = model_config
    # trainer = ppo.PPOTrainer(env="pom", config=config)
    config["lr"] = 0.0001
    # config["lr_schedule"] = [[0, 5e-4], [2000000, 5e-5], [4000000, 1e-5], [6000000, 1e-6], [8000000, 1e-7]]
    # config["vf_clip_param"] = 0.5
    config["grad_clip"] = 0.5
    config["use_gae"] = False
    # trainer = a3c.a2c.A2CTrainer(env="pom", config=config)
    trainer = ppo.appo.APPOTrainer(env="pom", config=config)

    # Can optionally call trainer.restore(path) to load a checkpoint.
    for i in range(30000):
        result = trainer.train()
        print(pretty_print(result))
        del result
        if i % 100 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
            del checkpoint

def update_last_result(self, result, terminate=False):
    result.update(trial_id=self.trial_id, done=terminate)
    if self.experiment_tag:
        result.update(experiment_tag=self.experiment_tag)
    if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
        print("Result for {}:".format(self))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self.last_debug = time.time()
    self.set_location(Location(result.get("node_ip"), result.get("pid")))
    self.last_result = result
    self.last_update_time = time.time()
    self.result_logger.on_result(self.last_result)
    for metric, value in flatten_dict(result).items():
        if isinstance(value, Number):
            if metric not in self.metric_analysis:
                self.metric_analysis[metric] = {
                    "max": value,
                    "min": value,
                    "last": value
                }
            else:
                self.metric_analysis[metric]["max"] = max(
                    value, self.metric_analysis[metric]["max"])
                self.metric_analysis[metric]["min"] = min(
                    value, self.metric_analysis[metric]["min"])
                self.metric_analysis[metric]["last"] = value

def dqn_train(config, reporter):
    # Instantiate a trainer
    cfg = {
        # Max num timesteps for annealing schedules. Exploration is annealed from
        # 1.0 to exploration_fraction over this number of timesteps scaled by
        # exploration_fraction
        "schedule_max_timesteps": 1000000,
        # Minimum env steps to optimize for per train call. This value does
        # not affect learning, only the length of iterations.
        "timesteps_per_iteration": 1000,
        # Fraction of entire training period over which the exploration rate is
        # annealed
        "exploration_fraction": 0.1,
        # Final value of random action probability
        "exploration_final_eps": 0.02,
        "n_step": 3,
        "buffer_size": 500000,
        # "sample_batch_size": 32,
        # "train_batch_size": 128,
        # "learning_starts": 5000,
        # "target_network_update_freq": 5000,
        # "num_workers": NUM_WORKERS,
        # "per_worker_exploration": True,
        # "worker_side_prioritization": True,
        # "min_iter_time_s": 1,
    }
    trainer = DQNTrainer(config={**config, **cfg})

    while True:
        result = trainer.train()  # Executes one training step
        print(pretty_print(result))
        reporter(**result)  # notifies TrialRunner

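# Hedged usage sketch: a function trainable with the (config, reporter)
# signature like `dqn_train` above is typically launched through Ray Tune.
# The env name and stopping criterion below are placeholders.
import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        dqn_train,
        config={"env": "CartPole-v0"},  # merged with `cfg` inside dqn_train
        stop={"training_iteration": 100},
    )
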
def _train(self):
    # improve the Adversary policy
    print("-- Adversary Training --")
    print(pretty_print(self.adv_trainer.train()))

    # swap weights to synchronize
    self.agent_trainer.set_weights(self.adv_trainer.get_weights(["adversary0"]))

    # improve the Agent policy
    print("-- Agent Training --")
    output = self.agent_trainer.train()
    print(pretty_print(output))

    # swap weights to synchronize
    self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))
    return output

def yaniv_eval(trainer, eval_workers):
    print("\n\n\n************** EVALUATION **************")
    t = YanivTournament(env_config, [trainer])
    res = t.run(eval_num)
    print(pretty_print(res), "\n\n\n")

    eval_vs = "eval_rules_"
    metrics = {
        eval_vs + "draw_rate": res["game"]["avg_draws"],
        eval_vs + "avg_roundlen": res["game"]["avg_roundlen"],
        eval_vs + "win_rate": res["player"]["player_0"]["avg_wins"],
        eval_vs + "assaf_rate": res["player"]["player_0"]["avg_assafs"],
        eval_vs + "self_avg_losing_score": res["player"]["player_0"]["avg_losing_score"],
        eval_vs + "oppt_avg_losing_score": np.mean([
            val["avg_losing_score"]
            for key, val in res["player"].items()
            if key != "player_0"
        ]),
        eval_vs + "oppt_assaf_rate": np.mean([
            val["avg_assafs"]
            for key, val in res["player"].items()
            if key != "player_0"
        ]),
    }
    return metrics

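# Hedged sketch of how a custom evaluation function such as `yaniv_eval`
# above is usually wired into an RLlib trainer config (old trainer API);
# the surrounding settings are assumptions, not the original project's values.
eval_config = {
    "evaluation_interval": 10,       # run yaniv_eval every 10 train iterations
    "evaluation_num_workers": 2,
    "custom_eval_function": yaniv_eval,
}
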
def _train(self):
    # improve the Adversary policy
    print("-- Adversary Training --")
    original_weight = self.adv_trainer.get_weights(
        ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
    print(pretty_print(self.adv_trainer.train()))
    first_weight = self.adv_trainer.get_weights(
        ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

    # Check that the adversary weights are updating after training
    assert original_weight != first_weight, \
        "The adversary weight hasn't changed after training."

    # swap weights to synchronize
    self.agent_trainer.set_weights(
        self.adv_trainer.get_weights(["adversary"]))

    # improve the Agent policy
    print("-- Agent Training --")
    output = self.agent_trainer.train()

    new_weight = self.agent_trainer.get_weights(
        ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
    # Check that the adversary is not being trained while the agent trainer is training
    assert first_weight == new_weight, \
        "The adversary weight changed, but it should not have been updated!"

    # swap weights to synchronize
    self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))
    return output

def run_policy(trainer):
    start = int(round(time.time()))
    while True:
        elapsed = int(round(time.time())) - start
        if elapsed > runtime:
            break
        result = trainer.train()
        print(pretty_print(result))

def ray_dqn_learn(num_eps, agent, c_freq=10):
    total_eps = 0
    while total_eps <= num_eps:
        print("{}/{}".format(total_eps, num_eps))
        train_result = agent.train()
        total_eps += train_result['episodes_this_iter']
        logging.debug(pretty_print(train_result))
        if total_eps % c_freq == 0:
            agent.save()

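# Hedged usage example for `ray_dqn_learn` above; the environment name and
# worker count are placeholders.
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
agent = DQNTrainer(config={"num_workers": 2}, env="CartPole-v0")
ray_dqn_learn(num_eps=1000, agent=agent, c_freq=10)
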
def update_last_result(self, result, terminate=False):
    if terminate:
        result = result._replace(done=True)
    if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
        print("TrainingResult for {}:".format(self))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self.last_debug = time.time()
    self.last_result = result
    self.result_logger.on_result(self.last_result)

def update_last_result(self, result, terminate=False):
    result.update(trial_id=self.trial_id, done=terminate)
    if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
        print("Result for {}:".format(self))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self.last_debug = time.time()
    self.last_result = result
    self.last_update_time = time.time()
    self.result_logger.on_result(self.last_result)

def update_last_result(self, result, terminate=False):
    if terminate:
        result.update(done=True)
    if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
        print("Result for {}:".format(self))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self.last_debug = time.time()
    self.last_result = result
    self.last_update_time = time.time()
    self.result_logger.on_result(self.last_result)

def _train(self):
    self.set_theta_to_evaluators()

    # optimize policy under estimated reward
    for train_iter in range(self.num_train):
        # collect samples with new reward fnc
        samples = self.sample(self.train_batch_size)
        for sample in samples.policy_batches.values():
            sample.shuffle()
        # train local based on samples
        self.train_policy_by_samples(samples)

    res = collect_metrics(self.local_evaluator, self.remote_evaluators)
    print(pretty_print(res))

    samples = self.sample(self.train_batch_size)
    norm = self.update_theta(samples, self.theta_lr)
    res["custom_metrics"]["theta_norm"] = norm
    return res

def trainIncr(path, numIter):
    """ load trainer from checkpoint and do incremental training """
    trainer = loadTrainer(path)
    for i in range(numIter):
        print("\n**** next iteration " + str(i))
        HiLoPricingEnv.count = 0
        result = trainer.train()
        print(pretty_print(result))
        print("env reset count " + str(HiLoPricingEnv.count))
    return trainer

def main_ppo():
    import ray.rllib.agents.ppo as ppo

    wandb.init(project='duocopter', sync_tensorboard=True)
    ray.init()

    env_config = {
        'copter_weight_kg': 0.5,
        'g': 9.81,
        'max_thrust_N': 2*9.81,
        'max_wattage_W': 2*350,  # TODO: Use power curve
        'k1_m': 0.01,  # TODO: Change
        'k2_m': 24E-3,
        'theta_deg': 0,
        'dyn_fric_coeff': 0.14,
        'cart_height_m': 0.2104,
        'thrust_centerline_distance_m': 0.01,  # TODO: Change
        'dt': 5E-4,
        'max_height_m': 1.44,
        'sampling_rate_hz': 20,
        'log': True
    }

    config = ppo.DEFAULT_CONFIG.copy()
    config['num_workers'] = 10
    config['env_config'] = env_config
    config['lambda'] = 0.9
    config['lr'] = 5e-5
    config['rollout_fragment_length'] = 500
    config['model']['fcnet_hiddens'] = [64, 64]

    trainer = ppo.PPOTrainer(config=config, env=SimEnv)
    for i in range(300):
        result = trainer.train()
        print(pretty_print(result))

    # roll out the trained policy once for evaluation
    env = SimEnv(env_config)
    state = env.reset()
    done = False
    ep_reward = 0
    while not done:
        thrust = trainer.compute_action(state, explore=False)
        state, rw, done, _ = env.step(thrust)
        ep_reward += rw
    print(env.calc_rms())
    env.plot()

    checkpoint = trainer.save()
    print(checkpoint)

def main(restore_path=None, num_gpus=0, num_workers=1, num_training_iterations=100, checkpoint_freq=100):
    env_config = {
        'player_num': int(os.getenv('PLAYER_NUM', 1)),
        'game_config': {
            'await_connection_time': 120,
            'server_address': 'server',
            'pub_socket': str(os.getenv("PUB_SOCKET", "5555")),
            'sub_socket': '5563',
            'unit_config': {
                1: 33,
                2: 33,
                3: 34
            }
        },
    }

    logger.info('Starting game for player {}'.format(os.getenv('PLAYER_NUM')))

    ray.init(temp_dir='./results')

    config = ppo.DEFAULT_CONFIG.copy()
    config['num_gpus'] = num_gpus
    config['num_workers'] = num_workers
    config['monitor'] = False
    # config['evaluation_interval'] = 100
    # config['evaluation_num_episodes'] = 1
    # config['evaluation_config'] = {
    #
    # }
    config['env_config'] = env_config
    # register_env('everglades', Everglades)
    # config['num_cpus_per_worker'] = 1

    trainer = ppo.PPOTrainer(
        env='everglades',
        config=config,
    )
    if restore_path is not None:
        trainer.restore(restore_path)

    for i in range(num_training_iterations):
        result = trainer.train()
        logger.info(pretty_print(result))
        if i % checkpoint_freq == 0:
            checkpoint = trainer.save()
            logger.info(f'Saving checkpoint at iteration {i}: {checkpoint}')

    trainer.save()
    logger.info('Training over. Saving checkpoint')

def train(num_iters):
    trainer = PPOTrainer(
        env='SUMOEnv-v0',
        config={
            'model': {
                "conv_filters": [
                    [32, [4, 4], 8],
                    [64, [2, 2], 4],
                ],
            },
            'multiagent': {
                'policy_graphs': {
                    'cluster_648538736_648538737': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=5), {}),
                    '49228579': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=4), {}),
                    'cluster_2511020106_49297289': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=4), {}),
                    'cluster_298135838_49135231': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=3), {}),
                    'cluster_290051904_49145925': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=5), {}),
                    'cluster_290051912_298136030_648538909': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=3), {}),
                    'cluster_2511020102_2511020103_290051922_298135886': (
                        PPOPolicyGraph,
                        Box(low=0., high=1., shape=(32, 32, 1)),
                        Discrete(n=4), {}),
                },
                'policy_mapping_fn': function(lambda agent_id: agent_id),
            },
            'callbacks': {
                'on_episode_end': function(on_episode_end),
            },
            # 'num_workers': 4,
            # 'num_gpus_per_worker': 0.25,  # All workers on a single GPU
            # 'timesteps_per_iteration': 16000,
        })

    for i in range(num_iters):
        print(f'== Iteration {i}==')
        print(pretty_print(trainer.train()))

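# `on_episode_end` is referenced above but not defined in this snippet.
# A minimal sketch using the old RLlib callbacks-dict signature follows;
# the metric name is hypothetical.
def on_episode_end(info):
    episode = info["episode"]
    # record a simple custom metric; real code would pull SUMO statistics
    # (e.g. waiting times) out of the environment instead
    episode.custom_metrics["episode_length"] = episode.length
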
def update_config(config: dict, config_updates: dict):
    logger.warning("Updating default config values by: \n {}".format(pretty_print(config_updates)))
    recursive_dict_update(config, config_updates)

    # If the seed and experiment_name are changed, their copies in the env config should be updated as well
    # (in the config.yml file this is done with anchors and alias indicators).
    config['env_config'].update({'seed': config['seed'],
                                 'experiment_name': config['experiment_name']})

    if 'mode' in config['env_config'].keys():
        if config['env_config']['mode'] == 'debug':
            logger.warning(
                "Env_config.mode is 'debug', some hyperparameters will be overwritten by: \n {}".format(
                    pretty_print(config["debug_hparams"])))
            config["rllib_config"].update(config["debug_hparams"]["rllib_config"])
            config["ray_init_config"].update(config["debug_hparams"]["ray_init_config"])

        default_config = load_config(update_algo_hparams_from_algo_conf_file=False)
        if 'inference_hparams' not in config.keys():
            config['inference_hparams'] = default_config['inference_hparams']
        elif 'explore' not in config["inference_hparams"]["rllib_config"]:
            config["inference_hparams"]["rllib_config"]['explore'] = \
                default_config["inference_hparams"]["rllib_config"]['explore']
            # Setting explore to what is set in the default config (False) is important, because in many older
            # trainings this key is missing, in which case it is treated as True by rllib.

        if config['env_config']['mode'] == 'inference':
            logger.warning(
                "Env_config.mode is 'inference', some hyperparameters will be overwritten by: \n {}".format(
                    pretty_print(config["inference_hparams"])))
            config["rllib_config"].update(config["inference_hparams"]["rllib_config"])
            config["ray_init_config"].update(config["inference_hparams"]["ray_init_config"])

        assert config['env_config']['mode'] in ['train', 'inference', 'debug']

    # For loaded configs, the env config is replicated in rllib_config
    if 'env_config' in config['rllib_config'].keys():
        config['rllib_config']['env_config'].update(config['env_config'])

def test_agent_with_mask():
    initialize_ray(test_mode=True, local_mode=False)
    ckpt = "~/ray_results/0810-20seeds/PPO_BipedalWalker-v2_0_seed=20_2019" \
           "-08-10_16-54-37xaa2muqm/checkpoint_469/checkpoint-469"
    # ckpt = None
    ret_list = []
    agent = restore_agent_with_mask("PPO", ckpt, "BipedalWalker-v2")
    # agent.compute_action(np.ones(24))

    for i in range(10):
        test_reward = agent.train()
        print(pretty_print(test_reward))
        ret_list.append(test_reward)
    print("Test end")

    agent.get_policy().set_default({
        'fc_1_mask': np.ones([256, ]),
        'fc_2_mask': np.ones([256, ])
    })

    for i in range(10):
        test_reward2 = agent.train()
        print(pretty_print(test_reward2))
        ret_list.append(test_reward2)
    print("Test2 end")

    return test_reward, test_reward2, ret_list

def train(self, environment, env_config):
    ray.init()
    trainer = ppo.PPOTrainer(
        env=environment,
        config={"env_config": env_config},
    )

    while True:
        results = trainer.train()
        episodes = results.get("episodes_total")
        print(f"\nEpisodes: {episodes}")
        print(f'Reward mean: {results.get("episode_reward_mean")}\n')
        if episodes >= 1000:
            break

    print("\n\n============")
    print(pretty_print(results))

def log_result(self, trial: "Trial", result: Dict, error: bool = False):
    done = result.get("done", False) is True
    last_print = self._last_print[trial]
    if done and trial not in self._completed_trials:
        self._completed_trials.add(trial)
    if has_verbosity(Verbosity.V3_TRIAL_DETAILS) and (
        done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL
    ):
        print("Result for {}:".format(trial))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self._last_print[trial] = time.time()
    elif has_verbosity(Verbosity.V2_TRIAL_NORM) and (
        done or error or time.time() - last_print > DEBUG_PRINT_INTERVAL
    ):
        info = ""
        if done:
            info = " This trial completed."

        metric_name = self._metric or "_metric"
        metric_value = result.get(metric_name, -99.0)

        print_result_str = self._print_result(result)
        self._last_result_str[trial] = print_result_str

        error_file = os.path.join(trial.logdir, "error.txt")

        if error:
            message = (
                f"The trial {trial} errored with "
                f"parameters={trial.config}. "
                f"Error file: {error_file}"
            )
        elif self._metric:
            message = (
                f"Trial {trial} reported "
                f"{metric_name}={metric_value:.2f} "
                f"with parameters={trial.config}.{info}"
            )
        else:
            message = (
                f"Trial {trial} reported "
                f"{print_result_str} "
                f"with parameters={trial.config}.{info}"
            )

        print(message)
        self._last_print[trial] = time.time()

def update_last_result(self, result, terminate=False):
    result.update(trial_id=self.trial_id, done=terminate)
    if self.experiment_tag:
        result.update(experiment_tag=self.experiment_tag)
    if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
        print("Result for {}:".format(self))
        print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
        self.last_debug = time.time()
    self.set_location(Location(result.get("node_ip"), result.get("pid")))
    self.last_result = result
    self.last_update_time = time.time()
    self.result_logger.on_result(self.last_result)
    for metric, value in flatten_dict(result).items():
        if isinstance(value, Number):
            if metric not in self.metric_analysis:
                self.metric_analysis[metric] = {
                    "max": value,
                    "min": value,
                    "avg": value,
                    "last": value
                }
                self.metric_n_steps[metric] = {}
                for n in self.n_steps:
                    key = "last-{:d}-avg".format(n)
                    self.metric_analysis[metric][key] = value
                    # Store n as string for correct restore.
                    self.metric_n_steps[metric][str(n)] = deque([value], maxlen=n)
            else:
                step = result["training_iteration"] or 1
                self.metric_analysis[metric]["max"] = max(
                    value, self.metric_analysis[metric]["max"])
                self.metric_analysis[metric]["min"] = min(
                    value, self.metric_analysis[metric]["min"])
                self.metric_analysis[metric]["avg"] = 1 / step * (
                    value + (step - 1) * self.metric_analysis[metric]["avg"])
                self.metric_analysis[metric]["last"] = value

                for n in self.n_steps:
                    key = "last-{:d}-avg".format(n)
                    self.metric_n_steps[metric][str(n)].append(value)
                    self.metric_analysis[metric][key] = sum(
                        self.metric_n_steps[metric][str(n)]) / len(
                            self.metric_n_steps[metric][str(n)])

def train(name, ray_config, debug=False):
    """
    Trains sam

    Parameters
    ----------
    name: name of yaml file
    ray_config: ray configuration
    debug: whether to test in editor

    Returns
    -------
    """
    ray.init()
    trainer_class = get_trainable_cls(ray_config['run'])
    default_config = trainer_class._default_config.copy()
    config = merge_dicts(default_config, ray_config['config'])

    now = datetime.datetime.now().strftime('%Y%m%d-%Hh%M')
    run = ray_config['run']
    model_name = f'{name}_{now}'
    print(f'\33]0;{model_name} - {name}\a', end='', flush=True)

    if debug:
        config['num_workers'] = 0
        config['num_envs_per_worker'] = 1
        # config['train_batch_size'] = 10
        config['env_config']['log_every'] = 2000
        trainer = trainer_class(config=config)
        policy = trainer.get_policy()
        model = policy.model
        print(model)
        for i in range(10):
            res = trainer.train()
            print(pretty_print(res))
    else:
        tune.run(
            run,
            name=model_name,
            # stop=ray_config['stop'],
            local_dir='results',
            config=config,
            checkpoint_at_end=True,
            verbose=2,
            # restore=RESTORE_PATH,
            checkpoint_freq=10)

    ray.shutdown()

def train_yseq_hvd(workers_per_node, epochs, **config):
    from zoo.orca.learn.pytorch import Estimator
    estimator = Estimator.from_torch(model=model_creator,
                                     optimizer=optimizer_creator,
                                     loss=loss_creator,
                                     workers_per_node=workers_per_node,
                                     config=config)
    stats = estimator.fit(train_data_creator, epochs=epochs)
    for s in stats:
        print(pretty_print(s))

    val_stats = estimator.evaluate(val_data_creator)
    val_loss = val_stats['val_loss']

    # retrieve the model
    yseq = estimator.get_model()
    estimator.shutdown()
    return yseq, val_loss

def train_yseq_dist(num_workers, epochs, **config):
    from zoo.orca.learn.pytorch.pytorch_trainer import PyTorchTrainer
    trainer = PyTorchTrainer(model_creator=model_creator,
                             data_creator=data_creator,
                             optimizer_creator=optimizer_creator,
                             loss_creator=loss_creator,
                             num_workers=num_workers,
                             config=config)
    stats = trainer.train(nb_epoch=epochs)
    for s in stats:
        print(pretty_print(s))

    worker_stats = trainer.validate(reduce_results=True)
    val_loss = worker_stats['val_loss']

    yseq = trainer.get_model()
    trainer.shutdown()
    return yseq, val_loss

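# The two snippets above rely on `model_creator`, `optimizer_creator`,
# `loss_creator` and the data creators defined elsewhere. A hedged, minimal
# sketch of what such creators might look like follows; the network, loss and
# hyperparameters are placeholders, not the project's real yseq model.
import torch
import torch.nn as nn

def model_creator(config):
    # toy regression model standing in for the real forecaster
    return nn.Sequential(nn.Linear(config.get("input_dim", 10), 1))

def optimizer_creator(model, config):
    return torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))

def loss_creator(config):
    return nn.MSELoss()
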
def go_train(config):
    trainer = ppo.PPOTrainer(config=config, env="continuousDoubleAuction-v0")
    if is_restore == True:
        trainer.restore(restore_path)

    g_store = ray.util.get_actor("g_store")

    result = None
    for i in range(num_iters):
        result = trainer.train()
        print(pretty_print(result))  # includes result["custom_metrics"]
        print("training loop = {} of {}".format(i + 1, num_iters))
        print("eps sampled so far {}".format(
            ray.get(g_store.get_eps_counter.remote())))
        print("result['experiment_id']", result["experiment_id"])

    return result

def __repr__(self):
    name = self.__class__.__name__
    args = [f"{self.observation_space},", f"{self.action_space},"]

    config = pretty_print(self.config).rstrip("\n")
    if "\n" in config:
        config = textwrap.indent(config, " " * 2)
        config = "{\n" + config + "\n}"
        args += [config]
        args_repr = "\n".join(args)
        args_repr = textwrap.indent(args_repr, " " * 2)
        constructor = f"{name}(\n{args_repr}\n)"
    else:
        args += [config]
        args_repr = " ".join(args[1:-1])
        constructor = f"{name}({args_repr})"
    return constructor

def train(num_iters, checkpoint_freq):
    obs_space = spaces.Dict({
        'obs': spaces.Box(low=-0.5, high=1.5, shape=(32, 32, 3), dtype=np.float32),
        'action_mask': spaces.Box(low=0, high=1, shape=(5, ), dtype=np.int32)
    })
    act_space = spaces.Discrete(n=5)

    trainer = DQNTrainer(
        env='SUMOEnv-v0',
        config={
            'model': {
                'custom_model': 'adaptive-trafficlight',
                'custom_options': {},
            },
            'multiagent': {
                'policy_graphs': {
                    'default_policy_graph': (
                        DQNPolicyGraph,
                        obs_space,
                        act_space,
                        {},
                    ),
                },
                'policy_mapping_fn': function(lambda _: 'default_policy_graph'),
            },
            'hiddens': [],  # Don't postprocess the action scores
            'callbacks': {
                'on_episode_end': function(on_episode_end),
            },
            # 'num_workers': 4,
            # 'num_gpus_per_worker': 0.25,  # All workers on a single GPU
            'timesteps_per_iteration': 20000,
        })

    for i in range(num_iters):
        print(f'== Iteration {i}==')
        print(pretty_print(trainer.train()))

        if i % checkpoint_freq == 0:
            checkpoint = trainer.save()
            print(f'\nCheckpoint saved at {checkpoint}\n')

def _extract_total_episode_eval_metrics(evaluated, totals, display_outputs=False):
    episode_totals = {}
    for k, rer in evaluated.items():
        from duckietown_world.rules import RuleEvaluationResult
        assert isinstance(rer, RuleEvaluationResult)
        for km, evaluated_metric in rer.metrics.items():
            assert isinstance(evaluated_metric, EvaluatedMetric)
            episode_totals[k] = evaluated_metric.total
            if k not in totals:
                totals[k] = [evaluated_metric.total]
            else:
                totals[k].append(evaluated_metric.total)

    if display_outputs:
        logger.info("\nEpisode total metrics: \n {}".format(
            pretty_print(episode_totals)))
    return totals

"policies_to_train": ["dqn_policy"], }, "gamma": 0.95, "n_step": 3, }) # disable DQN exploration when used by the PPO trainer ppo_trainer.optimizer.foreach_evaluator( lambda ev: ev.for_policy( lambda pi: pi.set_epsilon(0.0), policy_id="dqn_policy")) # You should see both the printed X and Y approach 200 as this trains: # info: # policy_reward_mean: # dqn_policy: X # ppo_policy: Y for i in range(args.num_iters): print("== Iteration", i, "==") # improve the DQN policy print("-- DQN --") print(pretty_print(dqn_trainer.train())) # improve the PPO policy print("-- PPO --") print(pretty_print(ppo_trainer.train())) # swap weights to synchronize dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"])) ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))