def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs)
            a = to_torch(b.act, device=v.device)
            r = to_torch(b.returns, device=v.device)
            a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            vf_loss = F.mse_loss(r[:, None], v)
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/actor': actor_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
def forward(self, s, a=None, **kwargs):
    s = to_torch(s, device=self.device, dtype=torch.float32)
    s = s.flatten(1)
    if a is not None:
        a = to_torch(a, device=self.device, dtype=torch.float32)
        a = a.flatten(1)
        s = torch.cat([s, a], dim=1)
    logits, h = self.preprocess(s)
    logits = self.last(logits)
    return logits
def forward(self, s, a=None, info={}):
    """(s, a) -> logits -> Q(s, a)"""
    s = to_torch(s, device=self.device, dtype=torch.float32)
    s = s.flatten(1)
    if a is not None:
        a = to_torch(a, device=self.device, dtype=torch.float32)
        a = a.flatten(1)
        s = torch.cat([s, a], dim=1)
    logits, h = self.preprocess(s)
    logits = self.last(logits)
    return logits
def test_utils_to_torch():
    batch = Batch(
        a=np.ones((1,), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1,), dtype=torch.float64)))
    a_torch_float = to_torch(batch.a, dtype=torch.float32)
    assert a_torch_float.dtype == torch.float32
    a_torch_double = to_torch(batch.a, dtype=torch.float64)
    assert a_torch_double.dtype == torch.float64
    batch_torch_float = to_torch(batch, dtype=torch.float32)
    assert batch_torch_float.a.dtype == torch.float32
    assert batch_torch_float.b.c.dtype == torch.float32
    assert batch_torch_float.b.d.dtype == torch.float32
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    self.optim.zero_grad()
    if self.action_type == "continuous":  # regression
        act = self(batch).act
        act_target = to_torch(batch.act, dtype=torch.float32,
                              device=act.device)
        loss = F.mse_loss(act, act_target)  # type: ignore
    elif self.action_type == "discrete":  # classification
        act = F.log_softmax(self(batch).logits, dim=-1)
        act_target = to_torch(batch.act, dtype=torch.long, device=act.device)
        loss = F.nll_loss(act, act_target)  # type: ignore
    loss.backward()
    self.optim.step()
    return {"loss": loss.item()}
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    self.optim.zero_grad()
    if self.mode == 'continuous':  # regression
        a = self(batch).act
        a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
        loss = F.mse_loss(a, a_)
    elif self.mode == 'discrete':  # classification
        # F.nll_loss expects log-probabilities, so normalize the raw logits
        a = F.log_softmax(self(batch).logits, dim=-1)
        a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
        loss = F.nll_loss(a, a_)
    loss.backward()
    self.optim.step()
    return {'loss': loss.item()}
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    self.optim.zero_grad()
    if self.mode == "continuous":  # regression
        a = self(batch).act
        a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
        loss = F.mse_loss(a, a_)  # type: ignore
    elif self.mode == "discrete":  # classification
        # F.nll_loss expects log-probabilities, so normalize the raw logits
        a = F.log_softmax(self(batch).logits, dim=-1)
        a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
        loss = F.nll_loss(a, a_)  # type: ignore
    loss.backward()
    self.optim.step()
    return {"loss": loss.item()}
def test_utils_to_torch():
    batch = Batch(
        a=np.float64(1.0),
        b=Batch(
            c=np.ones((1,), dtype=np.float32),
            d=torch.ones((1,), dtype=torch.float64)))
    a_torch_float = to_torch(batch.a, dtype=torch.float32)
    assert a_torch_float.dtype == torch.float32
    a_torch_double = to_torch(batch.a, dtype=torch.float64)
    assert a_torch_double.dtype == torch.float64
    batch_torch_float = to_torch(batch, dtype=torch.float32)
    assert batch_torch_float.a.dtype == torch.float32
    assert batch_torch_float.b.c.dtype == torch.float32
    assert batch_torch_float.b.d.dtype == torch.float32
    array_list = [float('nan'), 1.0]
    assert to_torch(array_list).dtype == torch.float64
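The two tests above exercise to_torch on scalars, arrays, nested Batch objects, and plain lists. Below is a minimal usage sketch, assuming the tianshou-style Batch/to_torch API the tests import; the field names (obs, act) are illustrative only:

import numpy as np
import torch
from tianshou.data import Batch, to_torch

# Build a nested Batch of numpy data, then convert it recursively in one call.
batch = Batch(obs=np.zeros((4, 3)), act=np.array([0, 1, 2, 1]))
batch = to_torch(batch, dtype=torch.float32, device="cpu")
assert isinstance(batch.obs, torch.Tensor)
assert batch.act.dtype == torch.float32  # every leaf is cast to the dtype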
def fqe_eval(policy, buffer):
    policy = deepcopy(policy)
    policy = policy.to(device)
    Fqe = FQE(policy, buffer,
              q_hidden_features=1024,
              q_hidden_layers=4,
              device=device)
    critic = Fqe.train_estimator(discount=0.99,
                                 target_update_period=100,
                                 critic_lr=1e-4,
                                 num_steps=250000)
    eval_size = 10000
    batch = buffer[:eval_size]
    data = to_torch(batch, torch.float32, device=device)
    o0 = data.obs
    a0 = policy.get_action(o0)
    init_sa = torch.cat((o0, a0), -1).to(device)
    with torch.no_grad():
        estimate_q0 = critic(init_sa)
    res = OrderedDict()
    res["Estimate_q0"] = estimate_q0.mean().item()
    return res
def forward(
    self,
    s1: Union[np.ndarray, torch.Tensor],
    act: Union[np.ndarray, torch.Tensor],
    s2: Union[np.ndarray, torch.Tensor],
    **kwargs: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Mapping: s1, act, s2 -> mse_loss, act_hat."""
    s1 = to_torch(s1, dtype=torch.float32, device=self.device)
    s2 = to_torch(s2, dtype=torch.float32, device=self.device)
    phi1, phi2 = self.feature_net(s1), self.feature_net(s2)
    act = to_torch(act, dtype=torch.long, device=self.device)
    phi2_hat = self.forward_model(
        torch.cat([phi1, F.one_hot(act, num_classes=self.action_dim)], dim=1)
    )
    mse_loss = 0.5 * F.mse_loss(phi2_hat, phi2, reduction="none").sum(1)
    act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
    return mse_loss, act_hat
def forward(
    self,
    s: Union[np.ndarray, torch.Tensor],
    state: Optional[Dict[str, torch.Tensor]] = None,
    info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Mapping: s -> flatten -> logits.

    In the evaluation mode, s should have shape ``[bsz, dim]``; in the
    training mode, s should have shape ``[bsz, len, dim]``. See the code
    and comments for more detail.
    """
    s = to_torch(s, device=self.device, dtype=torch.float32)
    # s is [bsz, len, dim] (training) or [bsz, dim] (evaluation); i.e. it
    # carries one more dimension in the training phase than in evaluation.
    if len(s.shape) == 2:
        s = s.unsqueeze(-2)
    s = self.fc1(s)
    self.nn.flatten_parameters()
    if state is None:
        s, (h, c) = self.nn(s)
    else:
        # we store the stacked data in [bsz, len, ...] format
        # but pytorch rnn needs [len, bsz, ...]
        s, (h, c) = self.nn(
            s, (state["h"].transpose(0, 1).contiguous(),
                state["c"].transpose(0, 1).contiguous()))
    s = self.fc2(s[:, -1])
    # please ensure the first dim is batch size: [bsz, len, ...]
    return s, {
        "h": h.transpose(0, 1).detach(),
        "c": c.transpose(0, 1).detach()
    }
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data."""
    # There is "obs" in the Batch.
    # obs_group: several groups; each group has a state.
    obs_group: torch.Tensor = to_torch(  # type: ignore
        batch.obs, device=self.device)
    act_group = []
    for obs in obs_group:
        # now obs is (state_dim)
        obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1)
        # now obs is (forward_sampled_times, state_dim)
        # decode(obs) generates an action and the actor perturbs it
        act = self.actor(obs, self.vae.decode(obs))
        # now act is (forward_sampled_times, action_dim)
        q1 = self.critic1(obs, act)
        # q1 is (forward_sampled_times, 1)
        max_indice = q1.argmax(0)
        act_group.append(act[max_indice].cpu().data.numpy().flatten())
    act_group = np.array(act_group)
    return Batch(act=act_group)
def forward(self, s, state=None, info={}): """In the evaluation mode, s should be with shape ``[bsz, dim]``; in the training mode, s should be with shape ``[bsz, len, dim]``. See the code and comment for more detail. """ s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(s.shape) == 2: s = s.unsqueeze(-2) s = self.fc1(s) self.nn.flatten_parameters() if state is None: s, (h, c) = self.nn(s) else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] s, (h, c) = self.nn(s, (state['h'].transpose( 0, 1).contiguous(), state['c'].transpose(0, 1).contiguous())) s = self.fc2(s[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] return s, { 'h': h.transpose(0, 1).detach(), 'c': c.transpose(0, 1).detach() }
def forward(self, s, state=None, info={}):
    self.device = next(self.parameters()).device
    s = to_torch(s, device=self.device, dtype=torch.float32)
    s = s.flatten(1)
    bs = s.shape[0]
    w1 = self.model_w1(
        torch.cat([
            s[:, effective_dim_start:effective_dim_end],
            s[:, self.n + effective_dim_start:self.n + effective_dim_end]
        ], dim=1)).reshape(bs, -1, self.n)
    w2 = self.model_w2(
        torch.cat([
            s[:, effective_dim_start:effective_dim_end],
            s[:, self.n + effective_dim_start:self.n + effective_dim_end]
        ], dim=1)).reshape(bs, self.m, -1)
    mu = w2.matmul(
        torch.tanh(w1.matmul(
            (s[:, :self.n] - s[:, self.n:]).unsqueeze(-1)))).squeeze(-1)
    shape = [1] * len(mu.shape)
    shape[1] = -1
    sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
    return (mu, sigma), None
def forward(self, obs, **kwargs):
    obs = to_torch(obs, device=self.linear.weight.device)
    out = self.linear(self.preprocess(obs))
    # mask out choices beyond each sample's own number of options
    mask = torch.arange(self.action_dim).expand(
        len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
    out[mask.to(out.device)] = float('-inf')
    return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._iter % self._freq == 0:
        self.sync_weight()
    self._iter += 1

    target_q = batch.returns.flatten()
    result = self(batch)
    imitation_logits = result.imitation_logits
    current_q = result.q_value[np.arange(len(target_q)), batch.act]
    act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
    q_loss = F.smooth_l1_loss(current_q, target_q)
    i_loss = F.nll_loss(
        F.log_softmax(imitation_logits, dim=-1), act  # type: ignore
    )
    reg_loss = imitation_logits.pow(2).mean()
    loss = q_loss + i_loss + self._weight_reg * reg_loss

    self.optim.zero_grad()
    loss.backward()
    self.optim.step()

    return {
        "loss": loss.item(),
        "loss/q": q_loss.item(),
        "loss/i": i_loss.item(),
        "loss/reg": reg_loss.item(),
    }
def forward(self, s, state=None, info={}):
    s = to_torch(s, device=self.device, dtype=torch.float)
    # s is [bsz, len, dim] (training) or [bsz, dim] (evaluation); i.e. it
    # carries one more dimension in the training phase than in evaluation.
    if len(s.shape) == 2:
        bsz, dim = s.shape
        length = 1
    else:
        bsz, length, dim = s.shape
    s = self.fc1(s.view([bsz * length, dim]))
    s = s.view(bsz, length, -1)
    self.nn.flatten_parameters()
    if state is None:
        s, (h, c) = self.nn(s)
    else:
        # we store the stacked data in [bsz, len, ...] format
        # but pytorch rnn needs [len, bsz, ...]
        s, (h, c) = self.nn(s, (state['h'].transpose(
            0, 1).contiguous(), state['c'].transpose(0, 1).contiguous()))
    s = self.fc2(s[:, -1])
    # please ensure the first dim is batch size: [bsz, len, ...]
    return s, {
        'h': h.transpose(0, 1).detach(),
        'c': c.transpose(0, 1).detach()
    }
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    all_dist = self(batch).logits
    act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
    curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
    target_dist = batch.returns.unsqueeze(1)
    # calculate each element's difference between curr_dist and target_dist
    u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
    huber_loss = (
        u * (self.tau_hat -
             (target_dist - curr_dist).detach().le(0.).float()).abs()
    ).sum(-1).mean(1)
    qr_loss = (huber_loss * weight).mean()
    # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
    # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
    batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
    # add CQL loss
    q = self.compute_q_value(all_dist, None)
    dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
    negative_sampling = q.logsumexp(1).mean()
    min_q_loss = negative_sampling - dataset_expec
    loss = qr_loss + min_q_loss * self._min_q_weight
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {
        "loss": loss.item(),
        "loss/qr": qr_loss.item(),
        "loss/cql": min_q_loss.item(),
    }
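For reference, the two terms above are the quantile Huber loss of QR-DQN (with kappa = 1, matching smooth_l1_loss) and the conservative regularizer of CQL. Writing theta_i for the current quantiles, T theta_j for the target quantiles, and matching the .sum(-1).mean(1) reductions in the code:

\[
\mathcal{L}_{\mathrm{QR}}
= \frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{N}
\bigl|\hat{\tau}_i - \mathbb{1}\{\mathcal{T}\theta_j - \theta_i \le 0\}\bigr|\,
L_{\kappa}\!\left(\mathcal{T}\theta_j - \theta_i\right),
\qquad
\mathcal{L}_{\mathrm{CQL}}
= \mathbb{E}_{s}\Bigl[\log\sum_{a} e^{Q(s,a)}\Bigr]
- \mathbb{E}_{(s,a)\sim\mathcal{D}}\bigl[Q(s,a)\bigr],
\]

and the total loss is \(\mathcal{L} = \mathcal{L}_{\mathrm{QR}} + w\,\mathcal{L}_{\mathrm{CQL}}\) with w = self._min_q_weight.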
def forward(self, s, state=None, info={}): """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(s.shape) == 2: s = s.unsqueeze(-2) self.nn.flatten_parameters() if state is None: s, (h, c) = self.nn(s) else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] s, (h, c) = self.nn(s, (state['h'].transpose( 0, 1).contiguous(), state['c'].transpose(0, 1).contiguous())) logits = s[:, -1] mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] return (mu, sigma), { 'h': h.transpose(0, 1).detach(), 'c': c.transpose(0, 1).detach() }
def forward(
    self,
    s: Union[np.ndarray, torch.Tensor],
    state: Optional[Any] = None,
    info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
    r"""Mapping: s -> Q(s, \*)."""
    s = to_torch(s, device=self.device, dtype=torch.float32)
    logits, h = self.preprocess(s, state)
    logits = self.last(logits)
    bs, action_num = s.size()
    mask = torch.ones_like(logits)
    if 'turn' in info.keys():
        for i in range(bs):
            if info['turn'][i] < self.least_ask and np.sum(
                    info['history'][i][self.disease_num:]) < (
                        action_num - self.disease_num):
                mask[i][:self.disease_num] = 0
    if 'history' in info.keys():
        # multiply by the history mask to rule out already-taken actions
        mask = mask * torch.tensor(
            np.ones_like(info['history']) - info['history'])
    if self.softmax_output:
        logits = torch.where(mask == 0, torch.full_like(logits, -1e16),
                             logits)
        logits = F.softmax(logits, dim=-1)
    return logits, h
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    target_q = batch.returns.flatten()
    act = to_torch(batch.act[:, np.newaxis],
                   device=target_q.device, dtype=torch.long)

    # critic 1
    current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()

    # critic 2
    current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
    td2 = current_q2 - target_q
    critic2_loss = (td2.pow(2) * weight).mean()
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()
    batch.weight = (td1 + td2) / 2.0  # prio-buffer

    # actor
    dist = self(batch).dist
    entropy = dist.entropy()
    with torch.no_grad():
        current_q1a = self.critic1(batch.obs)
        current_q2a = self.critic2(batch.obs)
        q = torch.min(current_q1a, current_q2a)
    actor_loss = -(self._alpha * entropy
                   + (dist.probs * q).sum(dim=-1)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()

    if self._is_auto_alpha:
        log_prob = -entropy.detach() + self._target_entropy
        alpha_loss = -(self._log_alpha * log_prob).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()
        self._alpha = self._log_alpha.detach().exp()

    self.sync_weight()

    result = {
        "loss/actor": actor_loss.item(),
        "loss/critic1": critic1_loss.item(),
        "loss/critic2": critic2_loss.item(),
    }
    if self._is_auto_alpha:
        result["loss/alpha"] = alpha_loss.item()
        result["alpha"] = self._alpha.item()  # type: ignore
    return result
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    losses = []
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = to_torch(b.act, device=dist.logits.device)
            r = to_torch(b.returns, device=dist.logits.device)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    return {'loss': losses}
def forward(
    self,
    s: Union[np.ndarray, torch.Tensor],
    state: Optional[Any] = None,
    info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
    """Mapping: s -> flatten -> logits."""
    if type(s) is tuple:
        s_0 = s[0]
        s_1 = s[1]
    elif type(s) is np.ndarray:
        if s.dtype == object:
            s_0 = s[:, 0]
            s_1 = s[:, 1]
    else:
        raise ValueError("Unsupported type %s!" % type(s))
    if not self.use_cam_obs:
        if self.use_phy_obs:
            robot_state = to_torch(np.stack(s_0), device=self.device,
                                   dtype=torch.float32)
            physical_state = to_torch(np.stack(s_1), device=self.device,
                                      dtype=torch.float32)
            logits = self.model([robot_state, physical_state])
        else:
            s = to_torch(s, device=self.device, dtype=torch.float32)
            s = s.reshape(s.size(0), -1)
            logits = self.model(s)
    else:
        img_top = to_torch(np.stack(s_0).transpose((0, 3, 1, 2)),
                           device=self.device, dtype=torch.float32)
        robot_state = to_torch(np.stack(s_1), device=self.device,
                               dtype=torch.float32)
        logits = self.model([img_top, robot_state])
    if self.dueling is not None:  # Dueling DQN
        q, v = self.Q(logits), self.V(logits)
        logits = q - q.mean(dim=1, keepdim=True) + v
    if self.softmax:
        logits = torch.softmax(logits, dim=-1)
    return logits, state
def train_estimator(self,
                    init_critic=None,
                    discount=0.99,
                    target_update_period=100,
                    critic_lr=1e-4,
                    num_steps=250000,
                    polyak=0.0,
                    batch_size=256):
    min_reward = self.buffer.rew.min()
    max_reward = self.buffer.rew.max()
    max_value = (1.2 * max_reward + 0.8 * min_reward) / (1 - discount)
    min_value = (1.2 * min_reward + 0.8 * max_reward) / (1 - discount)

    data = self.buffer.sample(batch_size)
    input_dim = data.obs.shape[-1] + data.act.shape[-1]
    critic = MLP(input_dim, 1, self.critic_hidden_features,
                 self.critic_hidden_layers).to(self._device)
    if init_critic is not None:
        critic.load_state_dict(init_critic.state_dict())
    critic_optimizer = torch.optim.Adam(critic.parameters(), lr=critic_lr)
    target_critic = deepcopy(critic).to(self._device)
    target_critic.requires_grad_(False)

    print('Training Fqe...')
    for t in tqdm(range(num_steps)):
        batch = self.buffer.sample(batch_size)
        data = to_torch(batch, torch.float32, device=self._device)
        r = data.rew
        terminals = data.done
        o1 = data.obs
        a1 = data.act
        o2 = data.obs_next

        a2 = self.policy.get_action(o2)
        q_target = target_critic(torch.cat((o2, a2), -1)).detach()
        current_discount = discount * (1 - terminals)
        backup = r + current_discount * q_target
        backup = torch.clamp(backup, min_value, max_value)  # prevent explosion

        q = critic(torch.cat((o1, a1), -1))
        critic_loss = ((q - backup) ** 2).mean()

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        if t % target_update_period == 0:
            with torch.no_grad():
                for p, p_targ in zip(critic.parameters(),
                                     target_critic.parameters()):
                    p_targ.data.mul_(polyak)
                    p_targ.data.add_((1 - polyak) * p.data)

    return critic
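The clamp bounds above follow from the geometric-series bound on discounted returns; the 1.2/0.8 mixing simply pads the reward range by 20% before dividing by 1 - gamma:

\[
\frac{r_{\min}}{1-\gamma} \;\le\; \sum_{t=0}^{\infty} \gamma^{t} r_t \;\le\; \frac{r_{\max}}{1-\gamma},
\qquad
\frac{1.2\,r_{\max} + 0.8\,r_{\min}}{1-\gamma}
= \frac{r_{\max} + 0.2\,(r_{\max} - r_{\min})}{1-\gamma}.
\]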
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    with torch.no_grad():
        obs_next_result = self(batch, input='obs_next')
        a_ = obs_next_result.act
        dev = a_.device
        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_),
        ) - self._alpha * obs_next_result.log_prob
        rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]
        done = to_torch(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = (rew + (1. - done) * self._gamma * target_q)

    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act)
    critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()

    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act)
    critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()

    # actor
    obs_result = self(batch)
    a = obs_result.act
    current_q1a = self.critic1(batch.obs, a)
    current_q2a = self.critic2(batch.obs, a)
    actor_loss = (self._alpha * obs_result.log_prob - torch.min(
        current_q1a, current_q2a)).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()

    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic1': critic1_loss.item(),
        'loss/critic2': critic2_loss.item(),
    }
def forward(
    self,
    x: Union[np.ndarray, torch.Tensor],
    state: Optional[Any] = None,
    info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
    r"""Mapping: x -> Q(x, \*)."""
    if not isinstance(x, torch.Tensor):
        x = to_torch(x, device=self.device, dtype=torch.float32)
    return self.net(x), state
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    with torch.no_grad():
        a_ = self(batch, model='actor_old', input='obs_next').act
        dev = a_.device
        noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
        if self._noise_clip >= 0:
            noise = noise.clamp(-self._noise_clip, self._noise_clip)
        a_ += noise
        a_ = a_.clamp(self._range[0], self._range[1])
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_))
        rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]
        done = to_torch(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = (rew + (1. - done) * self._gamma * target_q)

    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act)
    critic1_loss = F.mse_loss(current_q1, target_q)
    self.critic1_optim.zero_grad()
    critic1_loss.backward()
    self.critic1_optim.step()

    # critic 2
    current_q2 = self.critic2(batch.obs, batch.act)
    critic2_loss = F.mse_loss(current_q2, target_q)
    self.critic2_optim.zero_grad()
    critic2_loss.backward()
    self.critic2_optim.step()

    if self._cnt % self._freq == 0:
        actor_loss = -self.critic1(batch.obs, self(batch, eps=0).act).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self._last = actor_loss.item()
        self.actor_optim.step()
        self.sync_weight()
    self._cnt += 1
    return {
        'loss/actor': self._last,
        'loss/critic1': critic1_loss.item(),
        'loss/critic2': critic2_loss.item(),
    }
def forward(self, s, state=None, info={}): """s -> flatten -> logits""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.reshape(s.size(0), -1) logits = self.model(s) if self.dueling is not None: # Dueling DQN q, v = self.Q(logits), self.V(logits) logits = q - q.mean(dim=1, keepdim=True) + v if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state
def forward(self, s, **kwargs):
    s = to_torch(s, device=self.device, dtype=torch.float32)
    # s is [bsz, len, dim] (training) or [bsz, dim] (evaluation); i.e. it
    # carries one more dimension in the training phase than in evaluation.
    if len(s.shape) == 2:
        s = s.unsqueeze(-2)
    logits, _ = self.nn(s)
    logits = logits[:, -1]
    mu = self.mu(logits)
    shape = [1] * len(mu.shape)
    shape[1] = -1
    sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
    return (mu, sigma), None
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    with torch.no_grad():
        target_q = self.critic_old(
            batch.obs_next,
            self(batch, model='actor_old', input='obs_next', eps=0).act)
        dev = target_q.device
        rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]
        done = to_torch(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = (rew + (1. - done) * self._gamma * target_q)
    current_q = self.critic(batch.obs, batch.act)
    critic_loss = F.mse_loss(current_q, target_q)
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic': critic_loss.item(),
    }
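The SAC, TD3, and DDPG learn methods above all regress their critics onto the same one-step Bellman target, differing only in how the next action a' and the target value are formed:

\[
y = r + \gamma\,(1 - d)\,\bar{Q}(s', a'),
\]

with \(a' = \bar{\mu}(s')\) for DDPG; \(a' = \operatorname{clip}(\bar{\mu}(s') + \epsilon)\) and \(\bar{Q} = \min(\bar{Q}_1, \bar{Q}_2)\) for TD3; and \(a' \sim \pi(\cdot \mid s')\) with \(\bar{Q} = \min(\bar{Q}_1, \bar{Q}_2) - \alpha \log \pi(a' \mid s')\) for SAC.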