Exemplo n.º 1
0
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     """Run advantage actor-critic updates over ``batch``.

     :param batch: training data; ``batch.returns`` is standardized in
         place when ``self._rew_norm`` is set and its std is not ~0.
     :param batch_size: minibatch size passed to ``batch.split``.
     :param repeat: number of passes over ``batch``.
     :return: per-minibatch values for ``loss``, ``loss/actor``,
         ``loss/vf`` and ``loss/ent``.
     """
     self._batch = batch_size
     r = batch.returns
     if self._rew_norm and not np.isclose(r.std(), 0):
         batch.returns = (r - r.mean()) / r.std()
     losses, actor_losses, vf_losses, ent_losses = [], [], [], []
     for _ in range(repeat):
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
             # Flatten the critic output so it lines up with the 1-D
             # returns: ``r - v`` with ``v`` of shape [bsz, 1] would
             # silently broadcast into a [bsz, bsz] matrix.
             v = self.critic(b.obs).flatten()
             a = to_torch(b.act, device=v.device)
             r = to_torch(b.returns, device=v.device)
             # Put the batch axis last so the [bsz] advantage broadcasts
             # over log-probabilities of shape [bsz] or [bsz, act_dim].
             log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
             a_loss = -(log_prob * (r - v).detach()).mean()
             vf_loss = F.mse_loss(r, v)
             ent_loss = dist.entropy().mean()
             loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
             loss.backward()
             if self._grad_norm:
                 nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                          list(self.critic.parameters()),
                                          max_norm=self._grad_norm)
             self.optim.step()
             actor_losses.append(a_loss.item())
             vf_losses.append(vf_loss.item())
             ent_losses.append(ent_loss.item())
             losses.append(loss.item())
     return {
         'loss': losses,
         'loss/actor': actor_losses,
         'loss/vf': vf_losses,
         'loss/ent': ent_losses,
     }
Exemplo n.º 2
0
 def forward(self, s, a=None, **kwargs):
     """Return ``self.last`` applied to the preprocessed network input.

     ``s`` (and ``a`` when given) are converted to float32 tensors on
     ``self.device`` and flattened from dim 1. When ``a`` is given it is
     concatenated after ``s`` along dim 1 before ``self.preprocess``.
     """
     obs = to_torch(s, device=self.device, dtype=torch.float32).flatten(1)
     if a is None:
         net_input = obs
     else:
         act = to_torch(a, device=self.device, dtype=torch.float32).flatten(1)
         net_input = torch.cat([obs, act], dim=1)
     features, _hidden = self.preprocess(net_input)
     return self.last(features)
Exemplo n.º 3
0
 def forward(self, s, a=None, info={}):
     """(s, a) -> logits -> Q(s, a)

     Both inputs become float32 tensors on ``self.device`` flattened from
     dim 1; the action (if any) is appended to the observation along dim 1.
     ``info`` is accepted for interface compatibility and not used.
     """
     parts = [to_torch(s, device=self.device, dtype=torch.float32).flatten(1)]
     if a is not None:
         parts.append(
             to_torch(a, device=self.device, dtype=torch.float32).flatten(1))
     x = parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)
     hidden, _ = self.preprocess(x)
     return self.last(hidden)
Exemplo n.º 4
0
def test_utils_to_torch():
    """Check that ``to_torch`` casts arrays and nested batches to a dtype."""
    batch = Batch(
        a=np.ones((1, ), dtype=np.float64),
        b=Batch(c=np.ones((1, ), dtype=np.float64),
                d=torch.ones((1, ), dtype=torch.float64)),
    )
    for wanted in (torch.float32, torch.float64):
        assert to_torch(batch.a, dtype=wanted).dtype == wanted
    converted = to_torch(batch, dtype=torch.float32)
    for leaf in (converted.a, converted.b.c, converted.b.d):
        assert leaf.dtype == torch.float32
Exemplo n.º 5
0
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     """Take one supervised imitation step on ``batch``.

     ``"continuous"`` actions are regressed with MSE; ``"discrete"`` actions
     are classified with NLL over the log-softmax of the model logits.

     :return: ``{"loss": <scalar loss value>}``.
     """
     self.optim.zero_grad()
     if self.action_type == "continuous":  # regression
         predicted = self(batch).act
         expert = to_torch(batch.act, dtype=torch.float32,
                           device=predicted.device)
         loss = F.mse_loss(predicted, expert)  # type: ignore
     elif self.action_type == "discrete":  # classification
         log_probs = F.log_softmax(self(batch).logits, dim=-1)
         expert = to_torch(batch.act, dtype=torch.long,
                           device=log_probs.device)
         loss = F.nll_loss(log_probs, expert)  # type: ignore
     loss.backward()
     self.optim.step()
     return {"loss": loss.item()}
Exemplo n.º 6
0
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     """Run one behavior-cloning gradient step on ``batch``.

     ``'continuous'`` mode regresses the predicted actions onto
     ``batch.act`` with MSE. ``'discrete'`` mode classifies ``batch.act``
     from the model logits with a negative log-likelihood loss.

     :return: ``{'loss': <scalar loss value>}``.
     :raises ValueError: if ``self.mode`` is neither supported mode.
     """
     self.optim.zero_grad()
     if self.mode == 'continuous':
         a = self(batch).act
         a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
         loss = F.mse_loss(a, a_)
     elif self.mode == 'discrete':  # classification
         # nll_loss expects log-probabilities. log_softmax converts raw
         # logits and leaves log-probabilities unchanged, so this is
         # correct for either kind of model output.
         a = F.log_softmax(self(batch).logits, dim=-1)
         a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
         loss = F.nll_loss(a, a_)
     else:
         raise ValueError(f"Unknown mode {self.mode!r}")
     loss.backward()
     self.optim.step()
     return {'loss': loss.item()}
Exemplo n.º 7
0
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     """Run one behavior-cloning gradient step on ``batch``.

     ``"continuous"`` mode regresses the predicted actions onto
     ``batch.act`` with MSE. ``"discrete"`` mode classifies ``batch.act``
     from the model logits with a negative log-likelihood loss.

     :return: ``{"loss": <scalar loss value>}``.
     :raises ValueError: if ``self.mode`` is neither supported mode.
     """
     self.optim.zero_grad()
     if self.mode == "continuous":  # regression
         a = self(batch).act
         a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
         loss = F.mse_loss(a, a_)  # type: ignore
     elif self.mode == "discrete":  # classification
         # nll_loss expects log-probabilities. log_softmax converts raw
         # logits and leaves log-probabilities unchanged, so this is
         # correct for either kind of model output.
         a = F.log_softmax(self(batch).logits, dim=-1)
         a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
         loss = F.nll_loss(a, a_)  # type: ignore
     else:
         raise ValueError(f"Unknown mode {self.mode!r}")
     loss.backward()
     self.optim.step()
     return {"loss": loss.item()}
Exemplo n.º 8
0
def test_utils_to_torch():
    """Check ``to_torch`` dtype handling for scalars, batches and lists."""
    batch = Batch(a=np.float64(1.0),
                  b=Batch(c=np.ones((1, ), dtype=np.float32),
                          d=torch.ones((1, ), dtype=torch.float64)))
    for wanted in (torch.float32, torch.float64):
        assert to_torch(batch.a, dtype=wanted).dtype == wanted
    converted = to_torch(batch, dtype=torch.float32)
    for leaf in (converted.a, converted.b.c, converted.b.d):
        assert leaf.dtype == torch.float32
    # a plain list of Python floats (NaN included) converts to float64
    mixed = [float('nan'), 1.0]
    assert to_torch(mixed).dtype == torch.float64
Exemplo n.º 9
0
    def fqe_eval(policy, buffer):
        """Estimate a policy's initial value with Fitted Q Evaluation (FQE).

        A deep copy of ``policy`` is moved to ``device``, and an ``FQE``
        critic is trained on ``buffer`` with fixed hyper-parameters. The
        critic is then queried on the first ``eval_size`` buffer entries,
        each paired with the copied policy's action for that observation.

        :return: ``OrderedDict`` holding ``"Estimate_q0"``, the mean critic
            estimate over those entries.
        """
        # work on a copy so moving it to ``device`` leaves the caller's intact
        policy = deepcopy(policy)
        policy = policy.to(device)

        Fqe = FQE(policy,
                  buffer,
                  q_hidden_features=1024,
                  q_hidden_layers=4,
                  device=device)

        critic = Fqe.train_estimator(discount=0.99,
                                     target_update_period=100,
                                     critic_lr=1e-4,
                                     num_steps=250000)

        eval_size = 10000
        batch = buffer[:eval_size]
        data = to_torch(batch, torch.float32, device=device)
        o0 = data.obs
        # Evaluation only: run the policy as well as the critic without
        # autograd so no graph is recorded over the whole evaluation slice.
        with torch.no_grad():
            a0 = policy.get_action(o0)
            init_sa = torch.cat((o0, a0), -1).to(device)
            estimate_q0 = critic(init_sa)
        res = OrderedDict()
        res["Estimate_q0"] = estimate_q0.mean().item()
        return res
Exemplo n.º 10
0
 def forward(
     self,
     s1: Union[np.ndarray, torch.Tensor],
     act: Union[np.ndarray, torch.Tensor],
     s2: Union[np.ndarray, torch.Tensor],
     **kwargs: Any,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Mapping: s1, act, s2 -> mse_loss, act_hat.

     ``mse_loss`` is, per sample, half the summed squared error between the
     forward model's prediction of ``feature_net(s2)`` (from the features of
     ``s1`` and the one-hot action) and the actual features of ``s2``.
     ``act_hat`` is the inverse model's output for both feature vectors.
     """
     def as_float(x):
         return to_torch(x, dtype=torch.float32, device=self.device)

     s1, s2 = as_float(s1), as_float(s2)
     phi1 = self.feature_net(s1)
     phi2 = self.feature_net(s2)
     act = to_torch(act, dtype=torch.long, device=self.device)
     act_onehot = F.one_hot(act, num_classes=self.action_dim)
     phi2_hat = self.forward_model(torch.cat([phi1, act_onehot], dim=1))
     sq_err = F.mse_loss(phi2_hat, phi2, reduction="none")
     mse_loss = 0.5 * sq_err.sum(1)
     act_hat = self.inverse_model(torch.cat([phi1, phi2], dim=1))
     return mse_loss, act_hat
Exemplo n.º 11
0
    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Dict[str, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Mapping: s -> flatten -> logits.

        ``s`` is ``[bsz, dim]`` in evaluation (a length-1 sequence axis is
        inserted) or ``[bsz, len, dim]`` in training. It passes through
        ``fc1``, the LSTM ``self.nn``, and ``fc2`` on the last time step.
        ``state``, when given, holds ``"h"`` and ``"c"`` with the batch on
        the first dimension; the returned state uses the same layout and is
        detached from the graph.
        """
        x = to_torch(s, device=self.device, dtype=torch.float32)
        if x.dim() == 2:
            # evaluation input: add a length-1 sequence axis
            x = x.unsqueeze(-2)
        x = self.fc1(x)
        self.nn.flatten_parameters()
        if state is None:
            x, (h, c) = self.nn(x)
        else:
            # the stored state is batch-first; the LSTM wants batch on dim 1
            prev = (state["h"].transpose(0, 1).contiguous(),
                    state["c"].transpose(0, 1).contiguous())
            x, (h, c) = self.nn(x, prev)
        logits = self.fc2(x[:, -1])
        next_state = {
            "h": h.transpose(0, 1).detach(),
            "c": c.transpose(0, 1).detach(),
        }
        return logits, next_state
Exemplo n.º 12
0
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        Each observation in ``batch.obs`` is repeated
        ``self.forward_sampled_times`` times. Candidate actions are decoded
        by the VAE and perturbed by the actor, and the candidate with the
        highest ``critic1`` value is kept.

        :return: ``Batch(act=...)`` holding one numpy action row per
            observation.
        """
        # There is "obs" in the Batch
        # obs_group: several groups. Each group has a state.
        obs_group: torch.Tensor = to_torch(  # type: ignore
            batch.obs, device=self.device)
        act_group = []
        # Actions leave this method as numpy arrays, so no gradient can flow
        # back through them; skip recording the autograd graph entirely.
        with torch.no_grad():
            for obs in obs_group:
                # now obs is (state_dim)
                obs = obs.reshape(1, -1).repeat(self.forward_sampled_times, 1)
                # now obs is (forward_sampled_times, state_dim)

                # decode(obs) generates action and actor perturbs it
                act = self.actor(obs, self.vae.decode(obs))
                # now action is (forward_sampled_times, action_dim)
                q1 = self.critic1(obs, act)
                # q1 is (forward_sampled_times, 1)
                max_indice = q1.argmax(0)
                act_group.append(act[max_indice].cpu().numpy().flatten())
        act_group = np.array(act_group)
        return Batch(act=act_group)
Exemplo n.º 13
0
 def forward(self, s, state=None, info={}):
     """Run ``fc1`` -> LSTM -> ``fc2`` and return ``(logits, new_state)``.

     In the evaluation mode, s should be with shape ``[bsz, dim]``; in the
     training mode, s should be with shape ``[bsz, len, dim]``. A 2-D input
     gets a length-1 sequence axis, and only the last time step reaches
     ``fc2``. ``state`` and the returned state hold ``'h'`` and ``'c'``
     with the batch on the first dimension; the returned tensors are
     detached.
     """
     x = to_torch(s, device=self.device, dtype=torch.float32)
     if x.dim() == 2:
         x = x.unsqueeze(-2)
     x = self.fc1(x)
     self.nn.flatten_parameters()
     if state is not None:
         # the stored state is batch-first; the LSTM wants batch on dim 1
         h0 = state['h'].transpose(0, 1).contiguous()
         c0 = state['c'].transpose(0, 1).contiguous()
         x, (h, c) = self.nn(x, (h0, c0))
     else:
         x, (h, c) = self.nn(x)
     logits = self.fc2(x[:, -1])
     new_state = {
         'h': h.transpose(0, 1).detach(),
         'c': c.transpose(0, 1).detach(),
     }
     return logits, new_state
Exemplo n.º 14
0
    def forward(self, s, state=None, info={}):
        """Map ``s`` to Gaussian parameters and return ``((mu, sigma), None)``.

        ``s`` is flattened from dim 1. Columns
        ``[effective_dim_start, effective_dim_end)`` and the same range
        shifted by ``self.n`` are concatenated and fed to the hyper-networks
        ``model_w1``/``model_w2``. Their outputs are reshaped into per-sample
        matrices ``w1`` (``[bs, -1, n]``) and ``w2`` (``[bs, m, -1]``). Then
        ``mu = w2 @ tanh(w1 @ (s[:, :n] - s[:, n:]))``, and ``sigma`` is
        ``exp(self.sigma)`` broadcast to ``mu``'s shape.
        """
        # refresh the cached device from the parameters on every call
        self.device = next(self.parameters()).device
        x = to_torch(s, device=self.device, dtype=torch.float32).flatten(1)
        bs, n = x.shape[0], self.n
        start, end = effective_dim_start, effective_dim_end
        # both hyper-networks read the same features, so build them once
        hyper_in = torch.cat(
            [x[:, start:end], x[:, n + start:n + end]], dim=1)
        w1 = self.model_w1(hyper_in).reshape(bs, -1, n)
        w2 = self.model_w2(hyper_in).reshape(bs, self.m, -1)
        delta = (x[:, :n] - x[:, n:]).unsqueeze(-1)
        mu = w2.matmul(torch.tanh(w1.matmul(delta))).squeeze(-1)

        sigma_shape = [1] * mu.dim()
        sigma_shape[1] = -1
        sigma = (self.sigma.view(sigma_shape) + torch.zeros_like(mu)).exp()

        return (mu, sigma), None
Exemplo n.º 15
0
 def forward(self, obs, **kwargs):
     """Return softmax probabilities with unavailable choices masked out.

     ``obs['action_dim']`` gives, per sample, how many of the
     ``self.action_dim`` outputs are valid. Outputs at or beyond that index
     get logit ``-inf`` and therefore probability zero.

     :return: ``(probs, kwargs.get('state', None))``.
     """
     obs = to_torch(obs, device=self.linear.weight.device)
     out = self.linear(self.preprocess(obs))
     # To take care of choices with different numbers of options, build the
     # index range on the output's device: comparing a CPU ``arange`` with a
     # CUDA tensor raises a device-mismatch error.
     option_idx = torch.arange(self.action_dim, device=out.device)
     valid_count = obs['action_dim'].to(out.device).unsqueeze(1)
     mask = option_idx.expand(len(out), self.action_dim) >= valid_count
     out = out.masked_fill(mask, float('-inf'))
     return nn.functional.softmax(out, dim=-1), kwargs.get('state', None)
Exemplo n.º 16
0
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Take one gradient step on the Q, imitation and regularization losses.

        The target network is synced whenever ``self._iter`` is a multiple
        of ``self._freq`` (checked before the counter is incremented). The
        minimized loss is ``q_loss + i_loss + self._weight_reg * reg_loss``.

        :return: dict with the total loss and its three components.
        """
        if self._iter % self._freq == 0:
            self.sync_weight()
        self._iter += 1

        target_q = batch.returns.flatten()
        out = self(batch)
        logits = out.imitation_logits
        rows = np.arange(len(target_q))
        current_q = out.q_value[rows, batch.act]
        act = to_torch(batch.act, dtype=torch.long, device=target_q.device)

        q_loss = F.smooth_l1_loss(current_q, target_q)
        log_probs = F.log_softmax(logits, dim=-1)
        i_loss = F.nll_loss(log_probs, act)  # type: ignore
        reg_loss = logits.pow(2).mean()
        loss = q_loss + i_loss + self._weight_reg * reg_loss

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        stats = {"loss": loss, "loss/q": q_loss,
                 "loss/i": i_loss, "loss/reg": reg_loss}
        return {key: value.item() for key, value in stats.items()}
Exemplo n.º 17
0
 def forward(self, s, state=None, info={}):
     """Apply ``fc1`` per time step, an LSTM, then ``fc2`` on the last step.

     ``s`` is ``[bsz, dim]`` (evaluation, treated as one time step) or
     ``[bsz, len, dim]`` (training). ``state``, if given, holds ``'h'`` and
     ``'c'`` with the batch first; the returned state uses the same layout
     and is detached from the graph.
     """
     x = to_torch(s, device=self.device, dtype=torch.float)
     if len(x.shape) == 2:
         (bsz, dim), length = x.shape, 1
     else:
         bsz, length, dim = x.shape
     # fc1 runs on every time step at once: fold time into the batch axis
     x = self.fc1(x.view([bsz * length, dim])).view(bsz, length, -1)
     self.nn.flatten_parameters()
     hidden = None if state is None else (
         state['h'].transpose(0, 1).contiguous(),
         state['c'].transpose(0, 1).contiguous())
     x, (h, c) = self.nn(x) if hidden is None else self.nn(x, hidden)
     out = self.fc2(x[:, -1])
     return out, {'h': h.transpose(0, 1).detach(),
                  'c': c.transpose(0, 1).detach()}
Exemplo n.º 18
0
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     """One update of the quantile-regression loss plus a CQL penalty.

     The minimized loss is ``qr_loss + self._min_q_weight * min_q_loss``,
     where ``min_q_loss = mean(logsumexp over actions of Q) - mean(Q of the
     dataset action)``. When ``self._target`` is set, the target network
     is synced every ``self._freq`` calls. ``batch.weight`` is popped
     (default 1.0) and then overwritten with new per-sample values.

     :return: dict with ``loss``, ``loss/qr`` and ``loss/cql``.
     """
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     # per-sample loss weights; presumably importance weights from a
     # prioritized buffer — verify against the caller
     weight = batch.pop("weight", 1.0)
     all_dist = self(batch).logits
     act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
     # distribution of the taken action per sample (the indexing implies
     # all_dist is at least 3-D: batch, action, ...); unsqueeze(2) adds an
     # axis so it pairs element-wise with target_dist below
     curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
     target_dist = batch.returns.unsqueeze(1)
     # calculate each element's difference between curr_dist and target_dist
     u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
     # scale each Huber term by |tau_hat - 1{target - current <= 0}|
     huber_loss = (
         u * (self.tau_hat -
              (target_dist - curr_dist).detach().le(0.).float()).abs()
     ).sum(-1).mean(1)
     qr_loss = (huber_loss * weight).mean()
     # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
     # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
     batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
     # add CQL loss: push down Q on all actions (logsumexp) while pushing
     # up Q on the actions that appear in the data
     q = self.compute_q_value(all_dist, None)
     dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
     negative_sampling = q.logsumexp(1).mean()
     min_q_loss = negative_sampling - dataset_expec
     loss = qr_loss + min_q_loss * self._min_q_weight
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {
         "loss": loss.item(),
         "loss/qr": qr_loss.item(),
         "loss/cql": min_q_loss.item(),
     }
Exemplo n.º 19
0
 def forward(self, s, state=None, info={}):
     """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.

     Runs the LSTM ``self.nn`` over ``s`` (``[bsz, dim]`` gets a length-1
     time axis; ``[bsz, len, dim]`` is used as is) and maps the last time
     step to ``mu``, squashed into ``±self._max`` with tanh unless
     ``self._unbounded``. ``sigma`` is ``exp(self.sigma)`` broadcast to
     ``mu``. Returns ``((mu, sigma), state)`` where the state holds
     detached ``'h'``/``'c'`` with the batch on the first dimension.
     """
     x = to_torch(s, device=self.device, dtype=torch.float32)
     if x.dim() == 2:
         x = x.unsqueeze(-2)
     self.nn.flatten_parameters()
     if state is None:
         x, (h, c) = self.nn(x)
     else:
         # the stored state is batch-first; the LSTM wants batch on dim 1
         prev = (state['h'].transpose(0, 1).contiguous(),
                 state['c'].transpose(0, 1).contiguous())
         x, (h, c) = self.nn(x, prev)
     mu = self.mu(x[:, -1])
     if not self._unbounded:
         mu = self._max * torch.tanh(mu)
     sigma_shape = [1] * mu.dim()
     sigma_shape[1] = -1
     sigma = (self.sigma.view(sigma_shape) + torch.zeros_like(mu)).exp()
     new_state = {
         'h': h.transpose(0, 1).detach(),
         'c': c.transpose(0, 1).detach(),
     }
     return (mu, sigma), new_state
    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        When ``self.softmax_output`` is set, actions whose mask entry is 0
        get logit ``-1e16`` before the softmax. The mask starts as all ones.
        If ``info`` has ``'turn'``:

        * every sample whose turn is below ``self.least_ask`` and whose
          ``history[self.disease_num:]`` sums to less than
          ``action_num - self.disease_num`` has its first
          ``self.disease_num`` actions masked;
        * if ``info`` also has ``'history'``, the mask is multiplied by
          ``1 - history``.

        Note that ``action_num`` is read from the second dimension of ``s``.

        :return: ``(logits_or_probs, h)`` where ``h`` comes from
            ``self.preprocess``.
        """
        # ``None`` replaces the old shared mutable ``{}`` default
        if info is None:
            info = {}
        s = to_torch(s, device=self.device, dtype=torch.float32)
        logits, h = self.preprocess(s, state)
        logits = self.last(logits)
        bs, action_num = s.size()

        mask = torch.ones_like(logits)

        if 'turn' in info.keys():
            for i in range(bs):
                if info['turn'][i] < self.least_ask and np.sum(
                        info['history'][i][self.disease_num:]) < (
                            action_num - self.disease_num):
                    mask[i][:self.disease_num] = 0
            if 'history' in info.keys():
                # Multiply the mask by (1 - history). The tensor must live on
                # the mask's device: multiplying a CPU tensor with a CUDA
                # mask raises a device-mismatch error.
                mask = mask * torch.tensor(
                    np.ones_like(info['history']) - info['history'],
                    device=mask.device)

        if self.softmax_output:
            logits = torch.where(mask == 0, torch.full_like(logits, -1e16),
                                 logits)
            logits = F.softmax(logits, dim=-1)

        return logits, h
Exemplo n.º 21
0
    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """One discrete-action SAC update: two critics, the actor, then alpha.

        Each critic regresses the Q-value of the taken action onto
        ``batch.returns``, weighted by ``batch.weight`` (popped, default
        1.0). The actor minimizes ``-(alpha * entropy + sum_a pi(a|s) *
        min(Q1, Q2)(s, a))`` with the critic values computed under
        ``no_grad``. When ``self._is_auto_alpha`` is set, ``log_alpha`` is
        updated and ``self._alpha`` refreshed. ``sync_weight`` runs at the
        end.

        Side effect: ``batch.weight`` is replaced by the mean of the two
        critics' TD errors (marked ``prio-buffer`` below; presumably read
        back as new priorities by the caller — verify).

        :return: actor and critic losses, plus ``loss/alpha`` and ``alpha``
            when alpha is learned.
        """
        weight = batch.pop("weight", 1.0)
        target_q = batch.returns.flatten()
        # [bsz, 1] long index for ``gather`` along the action dimension
        act = to_torch(batch.act[:, np.newaxis],
                       device=target_q.device,
                       dtype=torch.long)

        # critic 1
        current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
        td1 = current_q1 - target_q
        critic1_loss = (td1.pow(2) * weight).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        # critic 2
        current_q2 = self.critic2(batch.obs).gather(1, act).flatten()
        td2 = current_q2 - target_q
        critic2_loss = (td2.pow(2) * weight).mean()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()
        # NOTE(review): td1/td2 are not detached here; confirm the consumer
        # of batch.weight detaches or converts them
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor
        dist = self(batch).dist
        entropy = dist.entropy()
        # critic values act as fixed targets for the actor step
        with torch.no_grad():
            current_q1a = self.critic1(batch.obs)
            current_q2a = self.critic2(batch.obs)
            q = torch.min(current_q1a, current_q2a)
        # expected Q under the policy probabilities plus the entropy bonus
        actor_loss = -(self._alpha * entropy +
                       (dist.probs * q).sum(dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            # entropy is detached so this step only moves log_alpha
            log_prob = -entropy.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        self.sync_weight()

        result = {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()  # type: ignore

        return result
Exemplo n.º 22
0
 def learn(self, batch: Batch, batch_size: int, repeat: int,
           **kwargs) -> Dict[str, List[float]]:
     """Policy-gradient update minimizing ``-sum(log_prob(act) * returns)``.

     ``batch.returns`` is standardized in place when ``self._rew_norm`` is
     set and its std is not close to zero. ``batch`` is then split into
     ``batch_size`` minibatches ``repeat`` times, with one optimizer step
     per minibatch.

     :return: ``{'loss': [per-minibatch loss values]}``.
     """
     returns = batch.returns
     if self._rew_norm and not np.isclose(returns.std(), 0):
         batch.returns = (returns - returns.mean()) / returns.std()
     history = []
     for _ in range(repeat):
         for minibatch in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(minibatch).dist
             act = to_torch(minibatch.act, device=dist.logits.device)
             ret = to_torch(minibatch.returns, device=dist.logits.device)
             loss = -(dist.log_prob(act) * ret).sum()
             loss.backward()
             self.optim.step()
             history.append(loss.item())
     return {'loss': history}
Exemplo n.º 23
0
    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: s -> flatten -> logits.

        ``s`` must be exactly a ``tuple`` ``(s_0, s_1)`` or a numpy array;
        an object-dtype array is split column-wise into ``s_0 = s[:, 0]``
        and ``s_1 = s[:, 1]``. Any other type (including ``torch.Tensor``)
        raises ``ValueError``.

        * ``use_cam_obs``: stacked ``s_0`` is transposed with
          ``(0, 3, 1, 2)`` (presumably NHWC images to NCHW — confirm) and
          passed to ``self.model`` together with the stacked ``s_1``.
        * ``use_phy_obs`` only: stacked ``s_0`` and ``s_1`` are passed.
        * neither: ``s`` itself is converted and flattened per sample.

        Dueling heads and softmax are applied afterwards when enabled.
        ``state`` is returned unchanged; ``info`` is not used.

        NOTE(review): a numpy array whose dtype is not ``object`` leaves
        ``s_0``/``s_1`` unassigned, so the camera/physical branches would
        then fail with ``UnboundLocalError``.
        """
        if type(s) is tuple:
            s_0 = s[0]
            s_1 = s[1]
        elif type(s) is np.ndarray:
            if s.dtype == object:
                s_0 = s[:, 0]
                s_1 = s[:, 1]
        else:
            raise ValueError("No type %s!" % type(s))

        if not self.use_cam_obs:
            if self.use_phy_obs:
                robot_state = to_torch(np.stack(s_0),
                                       device=self.device,
                                       dtype=torch.float32)
                physical_state = to_torch(np.stack(s_1),
                                          device=self.device,
                                          dtype=torch.float32)
                logits = self.model([robot_state, physical_state])
            else:
                s = to_torch(s, device=self.device, dtype=torch.float32)
                s = s.reshape(s.size(0), -1)
                logits = self.model(s)
        else:
            img_top = to_torch(np.stack(s_0).transpose((0, 3, 1, 2)),
                               device=self.device,
                               dtype=torch.float32)
            robot_state = to_torch(np.stack(s_1),
                                   device=self.device,
                                   dtype=torch.float32)
            logits = self.model([img_top, robot_state])
        if self.dueling is not None:  # Dueling DQN
            # subtract the mean of the Q head before adding the V head
            q, v = self.Q(logits), self.V(logits)
            logits = q - q.mean(dim=1, keepdim=True) + v
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
Exemplo n.º 24
0
    def train_estimator(self,
                        init_critic=None,
                        discount=0.99,
                        target_update_period=100,
                        critic_lr=1e-4,
                        num_steps=250000,
                        polyak=0.0,
                        batch_size=256):
        """Train a critic that evaluates ``self.policy`` by Fitted Q Evaluation.

        Each step samples ``batch_size`` transitions and regresses
        ``critic(obs, act)`` onto the backup
        ``rew + discount * (1 - done) * target_critic(obs_next, a')``, where
        ``a' = self.policy.get_action(obs_next)``. The backup is clamped to
        bounds derived from the buffer's reward range. Every
        ``target_update_period`` steps the target becomes
        ``polyak * target + (1 - polyak) * critic``, which is a hard copy for
        the default ``polyak=0.0``.

        :param init_critic: optional network whose ``state_dict`` initializes
            the critic; it must match the MLP built here.
        :param discount: discount factor used in the backup.
        :param target_update_period: steps between target-network updates.
        :param critic_lr: Adam learning rate for the critic.
        :param num_steps: number of gradient steps.
        :param polyak: weight kept on the old target parameters per update.
        :param batch_size: transitions sampled per step.
        :return: the trained critic.
        """
        min_reward = self.buffer.rew.min()
        max_reward = self.buffer.rew.max()

        # heuristic clamp bounds for the backup, from the reward range
        max_value = (1.2 * max_reward + 0.8 * min_reward) / (1 - discount)
        min_value = (1.2 * min_reward + 0.8 * max_reward) / (1 - discount)

        # a single sample, only used to size the critic's input layer
        data = self.buffer.sample(batch_size)
        input_dim = data.obs.shape[-1] + data.act.shape[-1]
        critic = MLP(input_dim, 1, self.critic_hidden_features,
                     self.critic_hidden_layers).to(self._device)
        if init_critic is not None:
            critic.load_state_dict(init_critic.state_dict())
        critic_optimizer = torch.optim.Adam(critic.parameters(), lr=critic_lr)
        target_critic = deepcopy(critic).to(self._device)
        target_critic.requires_grad_(False)

        print('Training Fqe...')
        for t in tqdm(range(num_steps)):
            batch = self.buffer.sample(batch_size)
            data = to_torch(batch, torch.float32, device=self._device)
            r = data.rew
            terminals = data.done
            # The critic has a single output column. Turn 1-D rewards/dones
            # into columns so the backup does not broadcast against that
            # column into a [batch, batch] matrix.
            if r.dim() == 1:
                r = r.unsqueeze(-1)
            if terminals.dim() == 1:
                terminals = terminals.unsqueeze(-1)
            o1 = data.obs
            a1 = data.act

            o2 = data.obs_next
            # the backup is a fixed regression target: record no graph
            with torch.no_grad():
                a2 = self.policy.get_action(o2)
                q_target = target_critic(torch.cat((o2, a2), -1))
                current_discount = discount * (1 - terminals)
                backup = r + current_discount * q_target
                backup = torch.clamp(backup, min_value,
                                     max_value)  # prevent explosion

            q = critic(torch.cat((o1, a1), -1))
            critic_loss = ((q - backup)**2).mean()

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            if t % target_update_period == 0:
                with torch.no_grad():
                    for p, p_targ in zip(critic.parameters(),
                                         target_critic.parameters()):
                        p_targ.data.mul_(polyak)
                        p_targ.data.add_((1 - polyak) * p.data)
        return critic
Exemplo n.º 25
0
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     """One SAC update: both critics, then the actor, then target sync.

     The critics regress onto ``rew + (1 - done) * gamma * (min(Q1_old,
     Q2_old) - alpha * log_prob)``, evaluated at actions sampled for
     ``obs_next``. The actor minimizes ``alpha * log_prob - min(Q1, Q2)``.
     ``batch.act`` is replaced in place by a float tensor.
     """
     with torch.no_grad():
         nxt = self(batch, input='obs_next')
         nxt_act = nxt.act
         dev = nxt_act.device
         batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
         soft_q = torch.min(
             self.critic1_old(batch.obs_next, nxt_act),
             self.critic2_old(batch.obs_next, nxt_act),
         ) - self._alpha * nxt.log_prob
         rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]
         done = to_torch(batch.done, dtype=torch.float, device=dev)[:, None]
         target_q = rew + (1. - done) * self._gamma * soft_q
     critic_losses = []
     for critic, optim in ((self.critic1, self.critic1_optim),
                           (self.critic2, self.critic2_optim)):
         td_loss = F.mse_loss(critic(batch.obs, batch.act), target_q)
         optim.zero_grad()
         td_loss.backward()
         optim.step()
         critic_losses.append(td_loss)
     cur = self(batch)
     q_min = torch.min(self.critic1(batch.obs, cur.act),
                       self.critic2(batch.obs, cur.act))
     actor_loss = (self._alpha * cur.log_prob - q_min).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     loss1, loss2 = critic_losses
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic1': loss1.item(),
         'loss/critic2': loss2.item(),
     }
Exemplo n.º 26
0
 def forward(
     self,
     x: Union[np.ndarray, torch.Tensor],
     state: Optional[Any] = None,
     info: Dict[str, Any] = {},
 ) -> Tuple[torch.Tensor, Any]:
     r"""Mapping: x -> Q(x, \*).

     Inputs that are not already tensors are converted to float32 tensors
     on ``self.device``; tensors are passed through untouched. ``state``
     is returned unchanged.
     """
     inputs = x if isinstance(x, torch.Tensor) else to_torch(
         x, device=self.device, dtype=torch.float32)
     return self.net(inputs), state
Exemplo n.º 27
0
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     """One TD3 update: twin critics every call, actor every ``_freq`` calls.

     Target actions from ``actor_old`` get Gaussian noise scaled by
     ``self._policy_noise``. The noise is clipped to ``±self._noise_clip``
     when that value is non-negative, and the actions are clamped to
     ``self._range``. Both critics regress onto
     ``rew + (1 - done) * gamma * min(Q1_old, Q2_old)``. On calls that skip
     the actor update, the last recorded actor loss is reported.
     """
     with torch.no_grad():
         next_act = self(batch, model='actor_old', input='obs_next').act
         dev = next_act.device
         noise = torch.randn(size=next_act.shape, device=dev)
         noise = noise * self._policy_noise
         if self._noise_clip >= 0:
             noise = noise.clamp(-self._noise_clip, self._noise_clip)
         next_act += noise
         next_act = next_act.clamp(self._range[0], self._range[1])
         next_q = torch.min(self.critic1_old(batch.obs_next, next_act),
                            self.critic2_old(batch.obs_next, next_act))
         rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]
         done = to_torch(batch.done, dtype=torch.float, device=dev)[:, None]
         target_q = rew + (1. - done) * self._gamma * next_q
     critic_losses = []
     for critic, optim in ((self.critic1, self.critic1_optim),
                           (self.critic2, self.critic2_optim)):
         td_loss = F.mse_loss(critic(batch.obs, batch.act), target_q)
         optim.zero_grad()
         td_loss.backward()
         optim.step()
         critic_losses.append(td_loss)
     update_actor = self._cnt % self._freq == 0
     if update_actor:
         actor_loss = -self.critic1(batch.obs, self(batch, eps=0).act).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self._last = actor_loss.item()
         self.actor_optim.step()
         self.sync_weight()
     self._cnt += 1
     loss1, loss2 = critic_losses
     return {
         'loss/actor': self._last,
         'loss/critic1': loss1.item(),
         'loss/critic2': loss2.item(),
     }
Exemplo n.º 28
0
 def forward(self, s, state=None, info={}):
     """s -> flatten -> logits

     With ``self.dueling`` set, the output is the ``Q`` head minus its mean
     over dim 1, plus the ``V`` head. With ``self.softmax`` set, a softmax
     over the last dim is applied. ``state`` is returned unchanged.
     """
     x = to_torch(s, device=self.device, dtype=torch.float32)
     out = self.model(x.reshape(x.size(0), -1))
     if self.dueling is not None:  # Dueling DQN
         q_head, v_head = self.Q(out), self.V(out)
         out = q_head - q_head.mean(dim=1, keepdim=True) + v_head
     return (torch.softmax(out, dim=-1) if self.softmax else out), state
Exemplo n.º 29
0
 def forward(self, s, **kwargs):
     """Map ``s`` through the recurrent net to Gaussian ``(mu, sigma)``.

     ``s`` is ``[bsz, dim]`` (evaluation; a length-1 time axis is added) or
     ``[bsz, len, dim]`` (training). Only the last time step feeds
     ``self.mu``. ``sigma`` is ``exp(self.sigma)`` broadcast to ``mu``'s
     shape. No recurrent state is returned.
     """
     seq = to_torch(s, device=self.device, dtype=torch.float32)
     if seq.dim() == 2:
         seq = seq.unsqueeze(-2)
     outputs, _ = self.nn(seq)
     mu = self.mu(outputs[:, -1])
     sigma_shape = [1] * mu.dim()
     sigma_shape[1] = -1
     log_sigma = self.sigma.view(sigma_shape) + torch.zeros_like(mu)
     return (mu, log_sigma.exp()), None
Exemplo n.º 30
0
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     """One DDPG update: critic TD regression, the actor, then target sync.

     The critic regresses onto
     ``rew + (1 - done) * gamma * Q_old(obs_next, actor_old(obs_next))``.
     The actor minimizes ``-Q(obs, actor(obs))``.
     """
     def apply(loss, optim):
         optim.zero_grad()
         loss.backward()
         optim.step()

     with torch.no_grad():
         next_q = self.critic_old(
             batch.obs_next,
             self(batch, model='actor_old', input='obs_next', eps=0).act)
         dev = next_q.device

         def column(values):
             return to_torch(values, dtype=torch.float, device=dev)[:, None]

         rew = column(batch.rew)
         done = column(batch.done)
         target_q = rew + (1. - done) * self._gamma * next_q
     critic_loss = F.mse_loss(self.critic(batch.obs, batch.act), target_q)
     apply(critic_loss, self.critic_optim)
     actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
     apply(actor_loss, self.actor_optim)
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic': critic_loss.item(),
     }