    optimizer.step()
    epsilon_tracker.frame(engine.state.iteration)
    # Periodically copy the online network's weights into the target network
    if engine.state.iteration % params.target_net_sync == 0:
        tgt_net.sync()
    if engine.state.iteration % EVAL_EVERY_FRAME == 0:
        # Sample a fixed set of states once and cache it on engine.state,
        # so the "values" metric is computed on the same states every time
        eval_states = getattr(engine.state, "eval_states", None)
        if eval_states is None:
            eval_states = buffer.sample(STATES_TO_EVALUATE)
            eval_states = [
                np.array(transition.state, copy=False)
                for transition in eval_states
            ]
            eval_states = np.array(eval_states, copy=False)
            engine.state.eval_states = eval_states
        engine.state.metrics["values"] = \
            common.calc_values_of_states(eval_states, net, device)
    return {
        "loss": loss_v.item(),
        "epsilon": selector.epsilon,
    }

engine = Engine(process_batch)
common.setup_ignite(engine, params, exp_source,
                    f"{NAME}={args.double}",
                    extra_metrics=('values', ))
engine.run(
    common.batch_generator(buffer, params.replay_initial,
                           params.batch_size))
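The values metric is the mean of the largest Q-value the network predicts on that fixed set of held-out states, a training diagnostic popularized by the original DQN paper: it should grow steadily while training makes progress. A minimal sketch of what such a helper could look like, assuming only the calc_values_of_states signature used above:

import numpy as np
import torch

@torch.no_grad()
def calc_values_of_states(states, net, device="cpu"):
    # Average the max-over-actions Q-value on a fixed state set,
    # processing the states in chunks to bound memory use
    mean_vals = []
    for batch in np.array_split(states, 64):
        states_v = torch.tensor(batch).to(device)
        action_values_v = net(states_v)            # (batch, n_actions)
        best_action_values_v = action_values_v.max(1)[0]
        mean_vals.append(best_action_values_v.mean().item())
    return float(np.mean(mean_vals))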
        if is_done:
            break
    engine.state.metrics['val_reward'] = reward
    engine.state.metrics['val_steps'] = steps
    print("Validation got %.3f reward in %d steps" % (reward, steps))
    best_val_reward = getattr(engine.state, "best_val_reward", None)
    if best_val_reward is None:
        engine.state.best_val_reward = reward
    elif best_val_reward < reward:
        # New validation record: checkpoint both networks
        print("Best validation reward updated: %s -> %s" % (
            best_val_reward, reward))
        save_prep_name = save_path / ("best_val_%.3f_p.dat" % reward)
        save_net_name = save_path / ("best_val_%.3f_n.dat" % reward)
        torch.save(prep.state_dict(), save_prep_name)
        torch.save(net.state_dict(), save_net_name)
        engine.state.best_val_reward = reward

@engine.on(ptan.ignite.EpisodeEvents.BEST_REWARD_REACHED)
def best_reward_updated(trainer: Engine):
    reward = trainer.state.metrics['avg_reward']
    if reward > 0:
        save_prep_name = save_path / ("best_train_%.3f_p.dat" % reward)
        save_net_name = save_path / ("best_train_%.3f_n.dat" % reward)
        torch.save(prep.state_dict(), save_prep_name)
        torch.save(net.state_dict(), save_net_name)
        print("%d: best avg training reward: %.3f, saved" % (
            trainer.state.iteration, reward))

engine.run(common.batch_generator(
    buffer, params.replay_initial, BATCH_SIZE))
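The loop that produces reward and steps above is a single validation episode played to completion. A minimal sketch of such a rollout, assuming a ptan-style agent that maps a batch of observations to a batch of actions and the classic four-value gym step API (run_validation_episode and max_steps are illustrative names, not from the original):

def run_validation_episode(env, agent, max_steps=1000):
    # Play one episode and report its total reward and length
    obs = env.reset()
    reward, steps, is_done = 0.0, 0, False
    while not is_done and steps < max_steps:
        actions, _ = agent([obs])          # ptan agents act on batches
        obs, r, is_done, _ = env.step(actions[0])
        reward += r
        steps += 1
    return reward, steps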
print("%d: val:%s" % (engine.state.iteration, res)) for key, val in res.items(): engine.state.metrics[key + "_val"] = val val_reward = res["episode_reward"] if getattr(engine.state, "best_val_reward", None) is None: engine.state.best_val_reward = val_reward if engine.state.best_val_reward < val_reward: print("Best validation reward updated: %.3f -> %.3f, model saved" % (engine.state.best_val_reward, val_reward)) engine.state.best_val_reward = val_reward path = saves_path / ("val_reward-%.3f.data" % val_reward) torch.save(net.state_dict(), path) event = ptan.ignite.PeriodEvents.ITERS_10000_COMPLETED tst_metrics = [m + "_tst" for m in validation.METRICS] tst_handler = tb_logger.OutputHandler(tag="test", metric_names=tst_metrics) tb.attach(engine, log_handler=tst_handler, event_name=event) val_metrics = [m + "_val" for m in validation.METRICS] val_handler = tb_logger.OutputHandler(tag="validation", metric_names=val_metrics) tb.attach(engine, log_handler=val_handler, event_name=event) engine.run(common.batch_generator(buffer, REPLAY_INITIAL, BATCH_SIZE))
    if engine.state.metrics.get(
            'avg_reward', LM_STOP_AVG_REWARD) > LM_STOP_AVG_REWARD:
        print("Mean reward reached %.2f, stop pretraining" %
              LM_STOP_AVG_REWARD)
        engine.should_terminate = True
    return {
        "loss": loss_t.item(),
    }

engine = Engine(process_batch)
run_name = f"lm-{args.params}_{args.run}"
save_path = pathlib.Path("saves") / run_name
save_path.mkdir(parents=True, exist_ok=True)
common.setup_ignite(engine, exp_source, run_name)
try:
    engine.run(common.batch_generator(buffer, BATCH_SIZE, BATCH_SIZE))
except KeyboardInterrupt:
    print("Interrupt received, saving the model...")
torch.save(prep.state_dict(), save_path / "prep.dat")
torch.save(cmd.state_dict(), save_path / "cmd.dat")

print("Using preprocessor and command generator")
prep.train(False)
cmd.train(False)

val_env = gym.make(val_env_id)
val_env = preproc.TextWorldPreproc(
    val_env, use_admissible_commands=False,
    keep_admissible_commands=True,
    reward_wrong_last_command=-0.1)
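common.batch_generator is what feeds the Engine throughout these snippets: it first fills the replay buffer to the requested size, then yields one training batch per iteration forever. A minimal sketch of such a generator, assuming a ptan ExperienceReplayBuffer with its populate/sample interface:

def batch_generator(buffer, initial, batch_size):
    # Warm up the replay buffer before the first gradient step
    buffer.populate(initial)
    while True:
        # One environment step per training batch keeps the
        # data-to-updates ratio constant
        buffer.populate(1)
        yield buffer.sample(batch_size)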
engine = Engine(process_batch)
common.setup_ignite(engine, PARAMS, exp_source, args.name,
                    extra_metrics=('test_reward', 'test_steps'))
best_test_reward = None

@engine.on(ptan_ignite.PeriodEvents.ITERS_10000_COMPLETED)
def test_network(engine):
    net.train(False)
    reward, steps = test_model(net, device, config)
    net.train(True)
    engine.state.metrics['test_reward'] = reward
    engine.state.metrics['test_steps'] = steps
    print("Test done: got %.3f reward after %.2f steps" % (
        reward, steps))

    global best_test_reward
    if best_test_reward is None:
        best_test_reward = reward
    elif best_test_reward < reward:
        print("Best test reward updated %.3f <- %.3f, save model" % (
            best_test_reward, reward))
        best_test_reward = reward
        torch.save(net.state_dict(), os.path.join(
            saves_path, "best_%.3f.dat" % reward))

engine.run(
    common.batch_generator(buffer, PARAMS.replay_initial,
                           PARAMS.batch_size))
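This handler tracks the record with a module-level global; the earlier snippets keep the same piece of state on engine.state instead, which avoids global mutable state and survives handler re-registration. A minimal sketch of that alternative as a reusable helper (update_best is an illustrative name, not from the original):

def update_best(state, reward, attr="best_test_reward"):
    # Track the running maximum on engine.state; return True when
    # the caller should checkpoint the model (first measurement is
    # recorded but not treated as an improvement)
    best = getattr(state, attr, None)
    improved = best is not None and reward > best
    if best is None or improved:
        setattr(state, attr, reward)
    return improved

# Usage inside the handler:
#   if update_best(engine.state, reward):
#       torch.save(net.state_dict(), ...)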
    loss_t = model.pretrain_policy_loss(cmd, commands, obs_t)
    loss_t.backward()
    optimizer.step()
    # Stop pretraining once the mean training reward is good enough
    if engine.state.metrics.get(
            'avg_reward', LM_STOP_AVG_REWARD) > LM_STOP_AVG_REWARD:
        print("Mean reward reached %.2f, stop pretraining" %
              LM_STOP_AVG_REWARD)
        engine.should_terminate = True
    return {
        "loss": loss_t.item(),
    }

engine = Engine(process_batch)
run_name = f"lm-{args.params}_{args.run}"
common.setup_ignite(engine, exp_source, run_name)
engine.run(common.batch_generator(buffer, BATCH_SIZE, BATCH_SIZE))
torch.save(prep.state_dict(), "prep.dat")
torch.save(cmd.state_dict(), "cmd.dat")
prep.load_state_dict(torch.load("prep.dat"))
cmd.load_state_dict(torch.load("cmd.dat"))

# DQN training using Preprocessor and Command generator
# as part of the environment
val_env = gym.make(val_env_id)
val_env = preproc.TextWorldPreproc(val_env)

net = model.DQNModel(obs_size=prep.obs_enc_size,
                     cmd_size=prep.obs_enc_size).to(device)
tgt_net = ptan.agent.TargetNet(net)
agent = model.CmdDQNAgent(env, net, cmd, prep, epsilon=1,
                          device=device)
exp_source = ptan.experience.ExperienceSourceFirstLast(