def save(self): """Save.""" state_dict = { 'pi': self.pi.state_dict(), 'qf1': self.qf1.state_dict(), 'qf2': self.qf2.state_dict(), 'target_pi': self.target_pi.state_dict(), 'target_qf1': self.target_qf1.state_dict(), 'target_qf2': self.target_qf2.state_dict(), 'opt_pi': self.opt_pi.state_dict(), 'opt_qf': self.opt_qf.state_dict(), 'env': misc.env_state_dict(self.env), 't': self.t } buffer_dict = self.buffer.state_dict() state_dict['buffer_format'] = nest.get_structure(buffer_dict) self.ckptr.save(state_dict, self.t) # save buffer seperately and only once (because it can be huge) np.savez(os.path.join(self.ckptr.ckptdir, 'buffer.npz'), **{f'{i:04d}': x for i, x in enumerate(nest.flatten(buffer_dict))})
def save(self): """Save.""" state_dict = { 'pi': self.pi.state_dict(), 'qf1': self.qf1.state_dict(), 'qf2': self.qf2.state_dict(), 'vf': self.vf.state_dict(), 'opt_pi': self.opt_pi.state_dict(), 'opt_qf1': self.opt_qf1.state_dict(), 'opt_qf2': self.opt_qf2.state_dict(), 'opt_vf': self.opt_vf.state_dict(), 'log_alpha': (self.log_alpha if self.automatic_entropy_tuning else None), 'opt_alpha': (self.opt_alpha.state_dict() if self.automatic_entropy_tuning else None), 'env': misc.env_state_dict(self.env), 't': self.t } buffer_dict = self.buffer.state_dict() state_dict['buffer_format'] = nest.get_structure(buffer_dict) self.ckptr.save(state_dict, self.t) # save buffer seperately and only once (because it can be huge) np.savez( os.path.join(self.ckptr.ckptdir, 'buffer.npz'), **{f'{i:04d}': x for i, x in enumerate(nest.flatten(buffer_dict))})
def load(self, t=None): """Load.""" state_dict = self.ckptr.load(t) if state_dict is None: self.t = 0 return self.t self.pi.load_state_dict(state_dict['pi']) self.qf.load_state_dict(state_dict['qf']) self.target_pi.load_state_dict(state_dict['target_pi']) self.target_qf.load_state_dict(state_dict['target_qf']) self.opt_pi.load_state_dict(state_dict['opt_pi']) self.opt_qf.load_state_dict(state_dict['opt_qf']) misc.env_load_state_dict(self.env, state_dict['env']) self.t = state_dict['t'] buffer_format = state_dict['buffer_format'] buffer_state = dict( np.load(os.path.join(self.ckptr.ckptdir, 'buffer.npz'))) buffer_state = nest.flatten(buffer_state) self.buffer.load_state_dict( nest.pack_sequence_as(buffer_state, buffer_format)) self.data_manager.manual_reset() return self.t
def load(self, t=None): """Load state dict.""" state_dict = self.ckptr.load(t) if state_dict is None: self.t = 0 return self.t self.pi.load_state_dict(state_dict['pi']) self.vf.load_state_dict(state_dict['vf']) self.opt_pi.load_state_dict(state_dict['opt_pi']) self.opt_vf.load_state_dict(state_dict['opt_vf']) self.kl_weight = state_dict['kl_weight'] misc.env_load_state_dict(self.env, state_dict['env']) self.ngu.load_state_dict(state_dict['ngu']) self.t = state_dict['t'] buffer_format = state_dict['buffer_format'] buffer_state = dict( np.load(os.path.join(self.ckptr.ckptdir, 'buffer.npz'), allow_pickle=True)) buffer_state = nest.flatten(buffer_state) self.buffer.load_state_dict( nest.pack_sequence_as(buffer_state, buffer_format)) self.buffer.env_reset() return self.t
def __init__(self,
             logdir,
             env_fn,
             policy_fn,
             qf_fn,
             nenv=1,
             optimizer=torch.optim.Adam,
             buffer_size=10000,
             frame_stack=1,
             learning_starts=1000,
             update_period=1,
             batch_size=256,
             policy_lr=1e-3,
             qf_lr=1e-3,
             gamma=0.99,
             target_update_period=1,
             policy_update_period=1,
             target_smoothing_coef=0.005,
             alpha=0.2,
             automatic_entropy_tuning=True,
             target_entropy=None,
             gpu=True,
             eval_num_episodes=1,
             record_num_episodes=1,
             log_period=1000):
    """Init."""
    self.logdir = logdir
    self.ckptr = Checkpointer(os.path.join(logdir, 'ckpts'))
    self.env_fn = env_fn
    self.nenv = nenv
    self.eval_num_episodes = eval_num_episodes
    self.record_num_episodes = record_num_episodes
    self.gamma = gamma
    self.buffer_size = buffer_size
    self.frame_stack = frame_stack
    self.learning_starts = learning_starts
    self.update_period = update_period
    self.batch_size = batch_size

    # Round the target and policy update periods down to a multiple of
    # update_period so they always coincide with a training step.
    if target_update_period < self.update_period:
        self.target_update_period = self.update_period
    else:
        self.target_update_period = target_update_period - (
            target_update_period % self.update_period)
    if policy_update_period < self.update_period:
        self.policy_update_period = self.update_period
    else:
        self.policy_update_period = policy_update_period - (
            policy_update_period % self.update_period)

    self.target_smoothing_coef = target_smoothing_coef
    self.log_period = log_period

    self.device = torch.device(
        'cuda:0' if gpu and torch.cuda.is_available() else 'cpu')

    self.env = VecEpisodeLogger(env_fn(nenv=nenv))
    eval_env = VecFrameStack(self.env, self.frame_stack)

    self.pi = policy_fn(eval_env)
    self.qf1 = qf_fn(eval_env)
    self.qf2 = qf_fn(eval_env)
    self.target_qf1 = qf_fn(eval_env)
    self.target_qf2 = qf_fn(eval_env)

    self.pi.to(self.device)
    self.qf1.to(self.device)
    self.qf2.to(self.device)
    self.target_qf1.to(self.device)
    self.target_qf2.to(self.device)

    self.opt_pi = optimizer(self.pi.parameters(), lr=policy_lr)
    self.opt_qf1 = optimizer(self.qf1.parameters(), lr=qf_lr)
    self.opt_qf2 = optimizer(self.qf2.parameters(), lr=qf_lr)

    # Initialize the target networks as copies of the online Q-networks.
    self.target_qf1.load_state_dict(self.qf1.state_dict())
    self.target_qf2.load_state_dict(self.qf2.state_dict())

    self.buffer = BatchedReplayBuffer(*[
        ReplayBuffer(buffer_size, frame_stack) for _ in range(self.nenv)
    ])
    self.data_manager = ReplayBufferDataManager(self.buffer,
                                                self.env,
                                                SACActor(self.pi),
                                                self.device,
                                                self.learning_starts,
                                                self.update_period)

    self.alpha = alpha
    self.automatic_entropy_tuning = automatic_entropy_tuning
    if self.automatic_entropy_tuning:
        if target_entropy:
            self.target_entropy = target_entropy
        else:
            target_entropies = nest.map_structure(
                lambda space: -np.prod(space.shape).item(),
                misc.unpack_space(self.env.action_space))
            self.target_entropy = sum(nest.flatten(target_entropies))

        self.log_alpha = torch.tensor(np.log([self.alpha]),
                                      requires_grad=True,
                                      device=self.device,
                                      dtype=torch.float32)
        self.opt_alpha = optimizer([self.log_alpha], lr=policy_lr)
    else:
        self.target_entropy = None
        self.log_alpha = None
        self.opt_alpha = None

    self.mse_loss = torch.nn.MSELoss()
    self.t = 0
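# When automatic entropy tuning is enabled and no target is given, the target
# entropy above defaults to the usual SAC heuristic of -|A| per action
# component (negative action dimensionality), summed over a possibly nested
# action space. A minimal sketch of that default with plain numpy and
# hypothetical action shapes (the real code walks the space via
# nest.map_structure over misc.unpack_space):
def _default_target_entropy_sketch(action_shapes=((6,), (2,))):
    """Illustrative only: -prod(shape) per action component, then summed."""
    # For a 6-dim and a 2-dim continuous action this yields -6 + -2 = -8.
    return sum(-np.prod(shape).item() for shape in action_shapes)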
def to_tensors(self):
    """Get the tensors of each component distribution as a flat list."""
    return [d.to_tensors() for d in nest.flatten(self.dists)]
def kl(self, other): "KL divergence." kls = nest.map_structure(lambda dists: dists[0].kl(dists[1]), nest.zip_structure(self.dists, other.dists)) return sum(nest.flatten(kls))
def entropy(self): """Entropy.""" entropies = nest.map_structure(lambda dist: dist.entropy(), self.dists) return sum(nest.flatten(entropies))
def from_tensors(self, tensors):
    """Create a ProductDistribution from a flat list of tensors."""
    flat_dists = [d.from_tensors(t)
                  for d, t in zip(nest.flatten(self.dists), tensors)]
    return ProductDistribution(nest.pack_sequence_as(
        flat_dists, nest.get_structure(self.dists)))
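# ProductDistribution treats a nested collection of independent component
# distributions as one joint distribution, so entropies and KL divergences are
# simply summed across components (as kl() and entropy() above do). A minimal
# sketch of that identity using torch.distributions rather than the library's
# own distribution wrappers:
def _product_entropy_sketch():
    """Illustrative only: joint entropy of independent parts is the sum."""
    import torch.distributions as td
    parts = [td.Normal(torch.zeros(3), torch.ones(3)),
             td.Categorical(logits=torch.zeros(5))]
    # H(p1, p2, ...) = H(p1) + H(p2) + ... for independent components.
    return sum(p.entropy().sum() for p in parts)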
def observation(self, obs):
    """Flatten a (possibly nested) observation into a single 1-D array."""
    return np.concatenate([x.flatten() for x in nest.flatten(obs)])
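# The observation method above collapses an arbitrarily nested observation
# into one 1-D vector by flattening every leaf array and concatenating them.
# A minimal sketch with a hypothetical dict observation and plain numpy (the
# real code uses nest.flatten to walk the nesting):
def _flatten_obs_sketch():
    """Illustrative only: {'pos': (3,), 'img': (2, 2)} -> shape (7,)."""
    obs = {'pos': np.zeros(3), 'img': np.ones((2, 2))}
    return np.concatenate([x.flatten() for x in obs.values()])  # shape (7,)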