def eval_minigrid(args):
    device = 'cuda'
    env = make_minigrid_env(args)
    state_shape = env.observation_space.shape
    action_shape = env.env.action_space.n
    net = DQN(state_shape[2], state_shape[0], state_shape[1],
              action_shape, device).to(device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        for subdir in os.listdir(args.resume_path):
            if not subdir.startswith("Q"):
                path = os.path.join(args.resume_path, subdir,
                                    "policy-%d.pth" % args.n)
                policy.load_state_dict(torch.load(path, map_location=device))
                print("Loaded agent from:", path)
    env.reset()
    # random walk for up to 10000 steps, recording the Q-values of the first
    # visit to each agent position
    Q_table = {}
    for _ in range(10000):
        action = np.random.randint(3)
        state, reward, done, _ = env.step(action)
        pos = tuple(env.agent_pos)
        if pos in Q_table:
            continue
        value = net(state.reshape(
            1, state_shape[0], state_shape[1], state_shape[2]
        ))[0].detach().cpu().numpy()
        Q_table[pos] = value
    with open(os.path.join(args.resume_path,
                           "Q_table%d.txt" % args.n), 'w') as f:
        for key, value in Q_table.items():
            print(key, ":", value, file=f)
    with open(os.path.join(args.resume_path,
                           "Q_tablepickle%d" % args.n), 'wb') as f:
        pickle.dump(Q_table, f)
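
# A minimal sketch (not part of the original script) of how the pickled
# Q-table written above could be reloaded and turned into a greedy policy per
# grid position; the file-name format mirrors eval_minigrid, while the
# resume_path/n arguments are hypothetical example inputs.
def load_greedy_actions(resume_path, n):
    import os
    import pickle
    import numpy as np
    with open(os.path.join(resume_path, "Q_tablepickle%d" % n), 'rb') as f:
        q_table = pickle.load(f)  # {(x, y): Q-value vector}
    # the argmax of each Q-value vector is the greedy action at that position
    return {pos: int(np.argmax(q)) for pos, q in q_table.items()}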
def test(args=get_args()):
    # Let's watch its performance!
    env = LimitWrapper(
        StateBonus(ImgObsWrapper(gym.make('MiniGrid-FourRooms-v0'))))
    args.state_shape = env.observation_space.shape
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # model
    net = DQN(args.state_shape[0], args.state_shape[1],
              args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       use_target_network=args.target_update_freq > 0,
                       target_update_freq=args.target_update_freq)
    policy.load_state_dict(torch.load('dqn.pth'))
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=args.render)
    print(f'Final reward: {result["rew"]}, length: {result["len"]}')
    collector.close()
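
# LimitWrapper above is project-specific; assuming it simply truncates
# episodes after a fixed number of steps, a minimal gym implementation could
# look like this sketch (the 100-step cap is an arbitrary example value).
import gym

class StepLimitWrapper(gym.Wrapper):
    def __init__(self, env, max_steps=100):
        super().__init__(env)
        self.max_steps = max_steps
        self._elapsed = 0

    def reset(self, **kwargs):
        self._elapsed = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._elapsed += 1
        if self._elapsed >= self.max_steps:
            done = True  # truncate the episode at the step cap
        return obs, reward, done, info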
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from:", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                                  save_only_last_obs=True,
                                  stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=[1] * args.test_num,
                                            render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy, train_collector, test_collector,
                               args.epoch, args.step_per_epoch,
                               args.collect_per_step, args.test_num,
                               args.batch_size, train_fn=train_fn,
                               test_fn=test_fn, stop_fn=stop_fn,
                               save_fn=save_fn, writer=writer,
                               test_in_train=False)
    pprint.pprint(result)
    watch()
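
# The linear epsilon schedule used in train_fn above, factored out as a
# standalone helper for illustration; the default values here are example
# numbers, the real ones come from get_args().
def linear_eps(env_step, eps_start=1.0, eps_final=0.05, decay_steps=1e6):
    # decay linearly from eps_start to eps_final over the first decay_steps
    # environment steps, then hold eps_final
    if env_step <= decay_steps:
        return eps_start - env_step / decay_steps * (eps_start - eps_final)
    return eps_final

# e.g. linear_eps(0) == 1.0, linear_eps(5e5) == 0.525, linear_eps(2e6) == 0.05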
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # log
    log_path = os.path.join(args.logdir, args.task, 'embedding')
    embedding_writer = SummaryWriter(log_path + '/with_init')

    embedding_net = embedding_prediction.Prediction(
        *args.state_shape, args.action_shape,
        args.device).to(device=args.device)
    embedding_net.apply(embedding_prediction.weights_init)
    if args.embedding_path:
        embedding_net.load_state_dict(torch.load(log_path + '/embedding.pth'))
        print("Loaded agent from:", log_path + '/embedding.pth')

    pre_buffer = ReplayBuffer(args.buffer_size, save_only_last_obs=True,
                              stack_num=args.frames_stack)
    pre_test_buffer = ReplayBuffer(args.buffer_size // 100,
                                   save_only_last_obs=True,
                                   stack_num=args.frames_stack)
    train_collector = Collector(None, train_envs, pre_buffer)
    test_collector = Collector(None, test_envs, pre_test_buffer)
    if args.embedding_data_path:
        pre_buffer = pickle.load(open(log_path + '/train_data.pkl', 'rb'))
        pre_test_buffer = pickle.load(open(log_path + '/test_data.pkl', 'rb'))
        train_collector.buffer = pre_buffer
        test_collector.buffer = pre_test_buffer
        print('load success')
    else:
        print('collect start')
        train_collector.collect(n_step=args.buffer_size, random=True)
        test_collector.collect(n_step=args.buffer_size // 100, random=True)
        print(len(train_collector.buffer))
        print(len(test_collector.buffer))
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        pickle.dump(pre_buffer, open(log_path + '/train_data.pkl', 'wb'))
        pickle.dump(pre_test_buffer, open(log_path + '/test_data.pkl', 'wb'))
        print('collect finish')

    # train the embedding network on the collected data
    def part_loss(x, device='cpu'):
        # binarization penalty: min((1 - x)^2, x^2) element-wise, zero only
        # when every activation is exactly 0 or 1
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=device, dtype=torch.float32)
        x = x.view(128, -1)
        temp = torch.cat(
            ((1 - x).pow(2.0).unsqueeze_(0), x.pow(2.0).unsqueeze_(0)), dim=0)
        return torch.sum(torch.min(temp, dim=0)[0])

    pre_optim = torch.optim.Adam(embedding_net.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(pre_optim, step_size=50000,
                                                gamma=0.1, last_epoch=-1)
    test_batch_data = test_collector.sample(batch_size=0)
    loss_fn = torch.nn.NLLLoss()
    for epoch in range(1, 100001):
        embedding_net.train()
        batch_data = train_collector.sample(batch_size=128)
        pred = embedding_net(batch_data['obs'], batch_data['obs_next'])
        x1, x2 = pred[1], pred[2]
        if not isinstance(batch_data['act'], torch.Tensor):
            act = torch.tensor(batch_data['act'], device=args.device,
                               dtype=torch.int64)
        # inverse-model action prediction loss plus the binarization penalty
        loss_1 = loss_fn(pred[0], act)
        loss_2 = 0.01 * (part_loss(x1, args.device) +
                         part_loss(x2, args.device)) / 128
        loss = loss_1 + loss_2
        embedding_writer.add_scalar('training loss1', loss_1.item(), epoch)
        embedding_writer.add_scalar('training loss2', loss_2.item(), epoch)
        embedding_writer.add_scalar('training loss', loss.item(), epoch)
        pre_optim.zero_grad()
        loss.backward()
        pre_optim.step()
        scheduler.step()

        if epoch % 10000 == 0 or epoch == 1:
            print(pre_optim.state_dict()['param_groups'][0]['lr'])
            correct = 0
            embedding_net.eval()
            with torch.no_grad():
                test_pred, x1, x2, _ = embedding_net(
                    test_batch_data['obs'], test_batch_data['obs_next'])
                if not isinstance(test_batch_data['act'], torch.Tensor):
                    act = torch.tensor(test_batch_data['act'],
                                       device=args.device, dtype=torch.int64)
                loss_1 = loss_fn(test_pred, act)
                loss_2 = 0.01 * (part_loss(x1, args.device) +
                                 part_loss(x2, args.device)) / 128
                loss = loss_1 + loss_2
                embedding_writer.add_scalar('test loss', loss.item(), epoch)
                correct += int((torch.argmax(test_pred, dim=1) == act).sum())
                print('Acc:', correct / len(test_batch_data))
    torch.save(embedding_net.state_dict(),
               os.path.join(log_path, 'embedding.pth'))
    embedding_writer.close()
    exit()

    # build the hash table
    # NOTE: the exit() above makes everything below unreachable; remove it to
    # run the DQN training stage with the learned embedding.
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(device=args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from:", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    pre_buffer.reset()
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    # pass a preprocess_fn to train_collector to reshape the reward
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch):
        # nature DQN setting, linear decay in the first 1M steps
        now = epoch * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(epoch):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=1 / 30)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy, train_collector, test_collector,
                               args.epoch, args.step_per_epoch,
                               args.collect_per_step, args.test_num,
                               args.batch_size, train_fn=train_fn,
                               test_fn=test_fn, stop_fn=stop_fn,
                               save_fn=save_fn, writer=writer,
                               test_in_train=False)
    pprint.pprint(result)
    watch()
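
# A small self-contained check (illustrative only) of the binarization
# penalty that part_loss applies inside the embedding training above:
# min((1 - x)^2, x^2) is zero exactly when an activation is 0 or 1, so the
# term pushes the embedding towards a binary code suitable for hashing.
import torch

def binarization_penalty(x):
    return torch.sum(torch.min((1 - x).pow(2), x.pow(2)))

print(binarization_penalty(torch.tensor([0.0, 1.0, 0.5])))
# tensor(0.2500) -- only the 0.5 entry is penalized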
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from:", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    if args.logger == "tensorboard":
        writer = SummaryWriter(log_path)
        writer.add_text("args", str(args))
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            project=args.task,
            name='dqn',
            run_id=args.resume_id,
            config=args,
        )

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        torch.save({'model': policy.state_dict()}, ckpt_path)
        return ckpt_path

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    pprint.pprint(result)
    watch()
def test_dqn(args=get_args()):
    env, train_envs, test_envs = make_atari_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        scale=args.scale_obs,
        frame_stack=args.frames_stack,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        target_update_freq=args.target_update_freq
    )
    if args.icm_lr_scale > 0:
        feature_net = DQN(
            *args.state_shape, args.action_shape, args.device,
            features_only=True
        )
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(
            feature_net.net,
            feature_dim,
            action_dim,
            hidden_sizes=[512],
            device=args.device
        )
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(
            policy, icm_net, icm_optim, args.icm_lr_scale,
            args.icm_reward_scale, args.icm_forward_loss_weight
        ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from:", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size,
        buffer_num=len(train_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=args.frames_stack
    )
    # collector
    train_collector = Collector(policy, train_envs, buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif "Pong" in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        torch.save({"model": policy.state_dict()}, ckpt_path)
        return ckpt_path

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size,
                buffer_num=len(test_envs),
                ignore_obs_next=True,
                save_only_last_obs=True,
                stack_num=args.frames_stack
            )
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(
                n_episode=args.test_num, render=args.render
            )
        rew = result["rews"].mean()
        print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    pprint.pprint(result)
    watch()
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(1)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from:", args.resume_path)
    if args.target_model_path:
        victim_policy = copy.deepcopy(policy)
        victim_policy.load_state_dict(torch.load(args.target_model_path))
        print("Loaded victim agent from:", args.target_model_path)
    else:
        victim_policy = policy
    args.target_policy, args.policy = "dqn", "dqn"
    args.perfect_attack = False
    adv_net = make_victim_network(args, victim_policy)
    adv_atk, _ = make_img_adv_attack(args, adv_net, targeted=False)
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True)
    # collector
    train_collector = adversarial_training_collector(
        policy, train_envs, adv_atk, buffer,
        atk_frequency=args.atk_freq, device=args.device)
    test_collector = adversarial_training_collector(
        policy, test_envs, adv_atk, buffer,
        atk_frequency=args.atk_freq, test=True, device=args.device)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy, policy_name='policy.pth'):
        torch.save(policy.state_dict(), os.path.join(log_path, policy_name))

    def stop_fn(mean_rewards):
        # never stop early: always train for the full budget
        return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)
        print("set eps =", policy.eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        assert args.target_model_path is not None
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[args.test_num],
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(policy, train_collector, test_collector,
                               args.epoch, args.step_per_epoch,
                               args.collect_per_step, args.test_num,
                               args.batch_size, train_fn=train_fn,
                               test_fn=test_fn, stop_fn=stop_fn,
                               save_fn=save_fn, writer=writer,
                               test_in_train=False)
    pprint.pprint(result)
    watch()
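
# make_img_adv_attack is project-specific. As an illustration of the kind of
# attack it might construct, here is a minimal untargeted FGSM on a batch of
# image observations (the model, shapes, and epsilon are hypothetical
# examples, not the project's actual attack implementation):
import torch
import torch.nn.functional as F

def fgsm_obs(model, obs, eps=0.01):
    # perturb each observation along the gradient sign that increases the
    # loss of the currently preferred action
    obs = obs.clone().detach().requires_grad_(True)
    q = model(obs)
    loss = F.cross_entropy(q, q.argmax(dim=1))
    loss.backward()
    return (obs + eps * obs.grad.sign()).detach()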