def test_ReplayBuffer():
    """Exercise the tianshou.data.ReplayBuffer API:
    buf.add(), buf.get(), buf.update(), buf.sample(), buf.reset(), len(buf).
    """
    buf1 = ReplayBuffer(size=15)
    for i in range(3):
        buf1.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1,
                 info={}, weight=None)
    print(len(buf1))
    print(buf1.obs)

    buf2 = ReplayBuffer(size=10)
    for i in range(15):
        buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1,
                 info={}, weight=None)
    print(buf2.obs)

    buf1.update(buf2)
    print(buf1.obs)

    index = [1, 3, 5]
    # `key` is a required argument
    print(buf2.get(index, key='obs'))
    print('--------------------')

    sample_data, indice = buf2.sample(batch_size=4)
    print(sample_data, indice)
    print(sample_data.obs == buf2[indice].obs)
    print('--------------------')

    # buf.reset() only resets the index, not the content.
    print(len(buf2))
    buf2.reset()
    print(len(buf2))
    print(buf2)
    print('--------------------')
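# A minimal sketch (not part of the test above) of the ring-buffer behaviour
# that test_ReplayBuffer exercises; it assumes the same tianshou
# ReplayBuffer.add() signature used in the test.
def _replaybuffer_wraparound_sketch():
    buf = ReplayBuffer(size=10)
    for i in range(15):
        buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
    # With size=10, the 15 additions wrap around: the write index cycles
    # 0..9 and then 0..4 again, so the oldest five transitions are
    # overwritten and buf.obs holds [10, 11, 12, 13, 14, 5, 6, 7, 8, 9].
    assert len(buf) == 10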
def test_episodic_returns(size=2560):
    fn = BasePolicy.compute_episodic_return
    buf = ReplayBuffer(20)
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
        info=Batch({
            'TimeLimit.truncated':
            np.array([False, False, False, False, False, True, False, False])
        }))
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    assert np.allclose(returns, ans)

    buf.reset()
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    assert np.allclose(returns, ans)

    buf.reset()
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=.1, gae_lambda=1)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    assert np.allclose(returns, ans)

    buf.reset()
    batch = Batch(
        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
        rew=np.array(
            [101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
    )
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
    returns, _ = fn(batch, buf, buf.sample_indices(0), v,
                    gamma=0.99, gae_lambda=0.95)
    ground_truth = np.array([
        454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085,
        295.387, 201., 474.2876, 390.1027, 299.476, 202.
    ])
    assert np.allclose(returns, ground_truth)

    buf.reset()
    batch = Batch(
        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
        rew=np.array(
            [101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
        info=Batch({
            'TimeLimit.truncated':
            np.array([False, False, False, True, False, False, False, True,
                      False, False, False, False])
        }))
    for b in batch:
        b.obs = b.act = 1
        buf.add(b)
    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
    returns, _ = fn(batch, buf, buf.sample_indices(0), v,
                    gamma=0.99, gae_lambda=0.95)
    ground_truth = np.array([
        454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571,
        293.5248, 199.02, 474.2876, 390.1027, 299.476, 202.
    ])
    assert np.allclose(returns, ground_truth)

    if __name__ == '__main__':
        buf = ReplayBuffer(size)
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        for b in batch:
            b.obs = b.act = 1
            buf.add(b)
        indices = buf.sample_indices(0)

        def vanilla():
            return compute_episodic_return_base(batch, gamma=.1)

        def optimized():
            return fn(batch, buf, indices, gamma=.1, gae_lambda=1.0)

        cnt = 3000
        print('GAE vanilla', timeit(vanilla, setup=vanilla, number=cnt))
        print('GAE optim ', timeit(optimized, setup=optimized, number=cnt))
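# compute_episodic_return_base, used by the benchmark above, is not defined in
# this excerpt. As a rough sketch (an assumption, not the original helper), a
# naive backward pass over the rewards that resets at episode boundaries
# reproduces the gamma=.1, gae_lambda=1 expectations of the tests above:
def _compute_episodic_return_naive(batch, gamma):
    # Fold discounted returns backwards; restart whenever done is True.
    returns = np.zeros_like(batch.rew)
    last = 0.0
    for i in reversed(range(len(batch.rew))):
        returns[i] = batch.rew[i]
        if not batch.done[i]:
            returns[i] += gamma * last
        last = returns[i]
    return returns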
def test_Fedppo(args=get_args()):
    torch.set_num_threads(1)  # for poor CPU
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.training_num)])
    #
    # test_envs = gym.make(args.task)
    # test_envs = DummyVectorEnv(
    #     [lambda: gym.make(args.task) for _ in range(args.test_num)])
    if args.data_quantity != 0:
        env.set_data_quantity(args.data_quantity)
    if args.data_quality != 0:
        env.set_data_quality(args.data_quality)
    if args.psi != 0:
        env.set_psi(args.psi)
    if args.nu != 0:
        env.set_nu(args.nu)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # train_envs.seed(args.seed)
    # test_envs.seed(args.seed)

    # model
    # server policy
    server_policy = build_policy(0, args)
    # client policy
    ND_policy = build_policy(1, args)
    RD_policy = build_policy(2, args)
    FD_policy = build_policy(3, args)

    # use plain replay buffers directly instead of a Collector
    server_buffer = ReplayBuffer(args.buffer_size)
    ND_buffer = ReplayBuffer(args.buffer_size)
    RD_buffer = ReplayBuffer(args.buffer_size)
    FD_buffer = ReplayBuffer(args.buffer_size)

    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # The trainer and tester are written by hand from here on.
    # To inspect the server's convergence, the client networks are not
    # trained at first.
    start_time = time.time()
    _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
    _server_act = _server_rew = _done = _info = None
    server_buffer.reset()
    _ND_act = _ND_rew = _RD_act = _RD_rew = _FD_act = _FD_rew = [None]
    ND_buffer.reset()
    RD_buffer.reset()
    FD_buffer.reset()

    all_server_costs = []
    all_ND_utility = []
    all_RD_utility = []
    all_FD_utility = []
    all_leak_probability = []

    for epoch in range(1, 1 + args.epoch):
        # each epoch collects N*T transitions, then trains M times
        # with mini-batches of size B
        server_costs = []
        ND_utility = []
        FD_utility = []
        RD_utility = []
        leak_probability = []
        payment = []
        expected_time = []
        training_time = []
        with tqdm.tqdm(total=args.step_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                # collect data without gradients
                # server
                _server_obs, _ND_obs, _RD_obs, _FD_obs = env.reset()
                server_batch = Batch(obs=_server_obs, act=_server_act,
                                     rew=_server_rew, done=_done,
                                     obs_next=None, info=_info, policy=None)
                with torch.no_grad():
                    server_result = server_policy(server_batch, None)
                _server_policy = [{}]
                _server_act = to_numpy(server_result.act)
                # ND
                ND_batch = Batch(obs=_ND_obs, act=_ND_act, rew=_ND_rew,
                                 done=_done, obs_next=None, info=_info,
                                 policy=None)
                with torch.no_grad():
                    ND_result = ND_policy(ND_batch, None)
                _ND_policy = [{}]
                _ND_act = to_numpy(ND_result.act)
                # RD
                RD_batch = Batch(obs=_RD_obs, act=_RD_act, rew=_RD_rew,
                                 done=_done, obs_next=None, info=_info,
                                 policy=None)
                with torch.no_grad():
                    RD_result = RD_policy(RD_batch, None)
                _RD_policy = [{}]
                _RD_act = to_numpy(RD_result.act)
                # FD
                FD_batch = Batch(obs=_FD_obs, act=_FD_act, rew=_FD_rew,
                                 done=_done, obs_next=None, info=_info,
                                 policy=None)
                with torch.no_grad():
                    FD_result = FD_policy(FD_batch, None)
                _FD_policy = [{}]
                _FD_act = to_numpy(FD_result.act)
                # print(_ND_act.shape)

                (server_obs_next, ND_obs_next, RD_obs_next, FD_obs_next,
                 _server_rew, _client_rew, _done, _info) = env.step(
                    _server_act[0], _ND_act[0], _RD_act[0], _FD_act[0])
                server_costs.append(_server_rew)
                ND_utility.append(_client_rew[0])
                RD_utility.append(_client_rew[1])
                FD_utility.append(_client_rew[2])
                leak_probability.append(_info[0]["leak"])
                payment.append(env.payment)
                expected_time.append(env.expected_time)
                training_time.append(env.global_time * env.time_lambda)

                # add the transitions to the replay buffers
                server_buffer.add(
                    Batch(obs=_server_obs[0], act=_server_act[0],
                          rew=_server_rew[0], done=_done[0],
                          obs_next=server_obs_next[0], info=_info[0],
                          policy=_server_policy[0]))
                ND_buffer.add(
                    Batch(obs=_ND_obs[0], act=_ND_act[0],
                          rew=_client_rew[0], done=_done[0],
                          obs_next=ND_obs_next[0], info=_info[0],
                          policy=_ND_policy[0]))
                RD_buffer.add(
                    Batch(obs=_RD_obs[0], act=_RD_act[0],
                          rew=_client_rew[1], done=_done[0],
                          obs_next=RD_obs_next[0], info=_info[0],
                          policy=_RD_policy[0]))
                FD_buffer.add(
                    Batch(obs=_FD_obs[0], act=_FD_act[0],
                          rew=_client_rew[2], done=_done[0],
                          obs_next=FD_obs_next[0], info=_info[0],
                          policy=_FD_policy[0]))
                t.update(1)
                _server_obs = server_obs_next
                _ND_obs = ND_obs_next
                _RD_obs = RD_obs_next
                _FD_obs = FD_obs_next

        all_server_costs.append(np.array(server_costs).mean())
        all_ND_utility.append(np.array(ND_utility).mean())
        all_RD_utility.append(np.array(RD_utility).mean())
        all_FD_utility.append(np.array(FD_utility).mean())
        all_leak_probability.append(np.array(leak_probability).mean())
        print("current bandwidth:", env.bandwidth)
        print("leak signal:", env.leak_NU, env.leak_FU)
        print("current server cost:", np.array(server_costs).mean())
        print("current device utility:", all_ND_utility[-1],
              all_RD_utility[-1], all_FD_utility[-1])
        print("leak probability:", all_leak_probability[-1])
        print("server_act:", _server_act[0])
        print("device_acts:", _ND_act[0], _RD_act[0], _FD_act[0])
        print("payment cost:", np.array(payment).mean())
        print("Expected time cost:", np.array(expected_time).mean())
        print("Training time cost:", np.array(training_time).mean())
        # print("server_act:", _server_act)
        # print("client_act:", _client_act)
        print("info:", env.communication_time, env.computation_time,
              env.K_theta)

        server_batch_data, server_indice = server_buffer.sample(0)
        server_batch_data = server_policy.process_fn(
            server_batch_data, server_buffer, server_indice)
        server_policy.learn(server_batch_data, args.batch_size,
                            args.repeat_per_collect)
        server_buffer.reset()

        ND_batch_data, ND_indice = ND_buffer.sample(0)
        ND_batch_data = ND_policy.process_fn(ND_batch_data, ND_buffer,
                                             ND_indice)
        ND_policy.learn(ND_batch_data, args.batch_size,
                        args.repeat_per_collect)
        ND_buffer.reset()

        RD_batch_data, RD_indice = RD_buffer.sample(0)
        RD_batch_data = RD_policy.process_fn(RD_batch_data, RD_buffer,
                                             RD_indice)
        RD_policy.learn(RD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        RD_buffer.reset()

        FD_batch_data, FD_indice = FD_buffer.sample(0)
        FD_batch_data = FD_policy.process_fn(FD_batch_data, FD_buffer,
                                             FD_indice)
        FD_policy.learn(FD_batch_data, args.batch_size,
                        args.repeat_per_collect)
        FD_buffer.reset()

    print("all_server_cost:", all_server_costs)
    print("all_ND_utility:", all_ND_utility)
    print("all_RD_utility:", all_RD_utility)
    print("all_FD_utility:", all_FD_utility)
    print("all_leak_probability:", all_leak_probability)
    plt.plot(all_server_costs)
    plt.show()
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # log
    log_path = os.path.join(args.logdir, args.task, 'embedding')
    embedding_writer = SummaryWriter(log_path + '/with_init')

    embedding_net = embedding_prediction.Prediction(
        *args.state_shape, args.action_shape, args.device
    ).to(device=args.device)
    embedding_net.apply(embedding_prediction.weights_init)
    if args.embedding_path:
        embedding_net.load_state_dict(torch.load(log_path + '/embedding.pth'))
        print("Loaded agent from: ", log_path + '/embedding.pth')
    # numel_list = [p.numel() for p in embedding_net.parameters()]
    # print(sum(numel_list), numel_list)

    pre_buffer = ReplayBuffer(args.buffer_size, save_only_last_obs=True,
                              stack_num=args.frames_stack)
    pre_test_buffer = ReplayBuffer(args.buffer_size // 100,
                                   save_only_last_obs=True,
                                   stack_num=args.frames_stack)
    train_collector = Collector(None, train_envs, pre_buffer)
    test_collector = Collector(None, test_envs, pre_test_buffer)

    if args.embedding_data_path:
        pre_buffer = pickle.load(open(log_path + '/train_data.pkl', 'rb'))
        pre_test_buffer = pickle.load(open(log_path + '/test_data.pkl', 'rb'))
        train_collector.buffer = pre_buffer
        test_collector.buffer = pre_test_buffer
        print('load success')
    else:
        print('collect start')
        train_collector = Collector(None, train_envs, pre_buffer)
        test_collector = Collector(None, test_envs, pre_test_buffer)
        train_collector.collect(n_step=args.buffer_size, random=True)
        test_collector.collect(n_step=args.buffer_size // 100, random=True)
        print(len(train_collector.buffer))
        print(len(test_collector.buffer))
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        pickle.dump(pre_buffer, open(log_path + '/train_data.pkl', 'wb'))
        pickle.dump(pre_test_buffer, open(log_path + '/test_data.pkl', 'wb'))
        print('collect finish')

    # train the embedding network on the collected data
    def part_loss(x, device='cpu'):
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=device, dtype=torch.float32)
        x = x.view(128, -1)
        temp = torch.cat(
            ((1 - x).pow(2.0).unsqueeze_(0), x.pow(2.0).unsqueeze_(0)), dim=0)
        temp_2 = torch.min(temp, dim=0)[0]
        return torch.sum(temp_2)

    pre_optim = torch.optim.Adam(embedding_net.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(pre_optim, step_size=50000,
                                                gamma=0.1, last_epoch=-1)
    test_batch_data = test_collector.sample(batch_size=0)
    loss_fn = torch.nn.NLLLoss()
    # train_loss = []

    for epoch in range(1, 100001):
        embedding_net.train()
        batch_data = train_collector.sample(batch_size=128)
        # print(batch_data)
        # print(batch_data['obs'][0] == batch_data['obs'][1])
        pred = embedding_net(batch_data['obs'], batch_data['obs_next'])
        x1 = pred[1]
        x2 = pred[2]
        # print(torch.argmax(pred[0], dim=1))
        if not isinstance(batch_data['act'], torch.Tensor):
            act = torch.tensor(batch_data['act'], device=args.device,
                               dtype=torch.int64)
        # print(pred[0].dtype)
        # print(act.dtype)
        # l2_norm = sum(p.pow(2.0).sum() for p in embedding_net.net.parameters())
        # loss = loss_fn(pred[0], act) + 0.001 * (part_loss(x1) + part_loss(x2)) / 64
        loss_1 = loss_fn(pred[0], act)
        loss_2 = 0.01 * (part_loss(x1, args.device) +
                         part_loss(x2, args.device)) / 128
        loss = loss_1 + loss_2
        # print(loss_1)
        # print(loss_2)
        embedding_writer.add_scalar('training loss1', loss_1.item(), epoch)
        embedding_writer.add_scalar('training loss2', loss_2, epoch)
        embedding_writer.add_scalar('training loss', loss.item(), epoch)
        # train_loss.append(loss.detach().item())

        pre_optim.zero_grad()
        loss.backward()
        pre_optim.step()
        scheduler.step()

        if epoch % 10000 == 0 or epoch == 1:
            print(pre_optim.state_dict()['param_groups'][0]['lr'])
            # print("Epoch: %d, Train Loss: %f" % (epoch, float(loss.item())))
            correct = 0
            numel_list = [p for p in embedding_net.parameters()][-2]
            print(numel_list)
            embedding_net.eval()
            with torch.no_grad():
                test_pred, x1, x2, _ = embedding_net(
                    test_batch_data['obs'], test_batch_data['obs_next'])
                if not isinstance(test_batch_data['act'], torch.Tensor):
                    act = torch.tensor(test_batch_data['act'],
                                       device=args.device, dtype=torch.int64)
                loss_1 = loss_fn(test_pred, act)
                loss_2 = 0.01 * (part_loss(x1, args.device) +
                                 part_loss(x2, args.device)) / 128
                loss = loss_1 + loss_2
                embedding_writer.add_scalar('test loss', loss.item(), epoch)
                # print("Test Loss: %f" % float(loss))
                print(torch.argmax(test_pred, dim=1))
                print(act)
                correct += int((torch.argmax(test_pred, dim=1) == act).sum())
            print('Acc:', correct / len(test_batch_data))

    torch.save(embedding_net.state_dict(),
               os.path.join(log_path, 'embedding.pth'))
    embedding_writer.close()
    # plt.figure()
    # plt.plot(np.arange(100000), train_loss)
    # plt.show()
    exit()

    # build the hash table
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    writer = SummaryWriter(log_path)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(device=args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    pre_buffer.reset()
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True,
                          stack_num=args.frames_stack)
    # collector
    # pass preprocess_fn to train_collector to reshape the reward
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        if env.env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=1 / 30)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
        save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: an environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to the collector,
        # one per environment
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given batch size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        # supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        # clear the buffer
        collector.reset_buffer()

    When collecting data from multiple environments into a single buffer, the
    cache buffers are turned on automatically; the collector may then return
    more data than the given limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """

    def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffers does not match the number ' \
                    'of input envs.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if needed).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id):
        """Reset self.state[id]."""
        if self.state is None:
            return
        if isinstance(self.state, list):
            self.state[id] = None
        elif isinstance(self.state, dict):
            for k in self.state:
                if isinstance(self.state[k], list):
                    self.state[k][id] = None
                elif isinstance(self.state[k], torch.Tensor) or \
                        isinstance(self.state[k], np.ndarray):
                    self.state[k][id] = 0
        elif isinstance(self.state, torch.Tensor) or \
                isinstance(self.state, np.ndarray):
            self.state[id] = 0

    def collect(self, n_step=0, n_episode=0, render=None):
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
            else:
                batch_data = Batch(
                    obs=self._make_batch(self._obs),
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
                    obs_next=None,
                    info=self._make_batch(self._info))
            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render is not None:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
        final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
class Collector(object):
    """docstring for Collector"""

    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffers does not match the number ' \
                    'of input envs.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        return self.env_num

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(obs=self._obs, act=self._act,
                                   rew=self._rew, done=self._done,
                                   obs_next=None, info=self._info)
            else:
                batch_data = Batch(obs=self._make_batch(self._obs),
                                   act=self._make_batch(self._act),
                                   rew=self._make_batch(self._rew),
                                   done=self._make_batch(self._done),
                                   obs_next=None,
                                   info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
                        'obs_next': obs_next[i], 'info': self._info[i]}
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(self._obs, self._act[0], self._rew,
                                self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
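# A minimal usage sketch (not part of the library) of the on-policy workflow
# with Collector.sample, following the example in the first Collector's
# docstring above. `policy` and `env` are assumed to be a tianshou policy and
# a gym environment supplied by the caller.
def _collector_sample_sketch(policy, env):
    collector = Collector(policy, env, buffer=ReplayBuffer(size=10000))
    collector.collect(n_episode=3)
    # batch_size=0 returns everything stored in the buffer, already passed
    # through policy.process_fn, so it can go straight into policy.learn().
    batch_data = collector.sample(batch_size=0)
    policy.learn(batch_data)
    # on-policy algorithms use the collected data only once
    collector.reset_buffer()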