Example #1
    def __init__(self, input_dim, n_nodes, node_dim):
        super(GraphVAE, self).__init__()
        # store parameters
        self.input_dim = input_dim
        self.n_nodes = n_nodes
        self.node_dim = node_dim

        # encoder: x -> h_x
        self.encoder = nn.Sequential(nn.Linear(input_dim, 512),
                                     nn.BatchNorm1d(512), nn.ELU(),
                                     nn.Linear(512, 512), nn.BatchNorm1d(512),
                                     nn.ELU(), nn.Linear(512, 256),
                                     nn.BatchNorm1d(256), nn.ELU(),
                                     nn.Linear(256, 128))
        # bottom-up inference: predicts parameters of P(z_i | x)
        self.bottom_up = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.ELU(),
                nn.Linear(128, node_dim),
                nn.Linear(node_dim, 2 * node_dim)  # split into mu and logvar
            ) for _ in range(n_nodes - 1)
        ])  # ignore z_n

        # top-down inference: predicts parameters of P(z_i | Pa(z_i))
        self.top_down = nn.ModuleList([
            nn.Sequential(
                nn.Linear((n_nodes - i - 1) * node_dim,
                          128),  # parents of z_i are z_{i+1} ... z_N
                nn.BatchNorm1d(128),
                nn.ELU(),
                nn.Linear(128, node_dim),
                nn.Linear(node_dim, 2 * node_dim)  # split into mu and logvar
            ) for i in range(n_nodes - 1)
        ])  # ignore z_n

        # decoder: (z_1, z_2 ... z_n) -> parameters of P(x)
        self.decoder = nn.Sequential(nn.Linear(node_dim * n_nodes, 256),
                                     nn.BatchNorm1d(256), nn.ELU(),
                                     nn.Linear(256, 512), nn.BatchNorm1d(512),
                                     nn.ELU(), nn.Linear(512, 512),
                                     nn.BatchNorm1d(512), nn.ELU(),
                                     nn.Linear(512, input_dim))

        # mean of Bernoulli variables c_{i,j} representing edges
        self.gating_params = nn.ParameterList([
            nn.Parameter(torch.empty(n_nodes - i - 1, 1, 1).fill_(0.5),
                         requires_grad=True) for i in range(n_nodes - 1)
        ])  # ignore z_n

        # distributions for sampling
        self.unit_normal = D.Normal(torch.zeros(self.node_dim),
                                    torch.ones(self.node_dim))
        self.gumbel = D.Gumbel(0., 1.)

        # other parameters / distributions
        self.tau = 1.0
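The constructor above only stores self.gumbel, self.tau, and the edge means in self.gating_params; it does not show how they are combined. As a rough illustration, a binary-concrete (Gumbel-Softmax) relaxation of one such Bernoulli gate might look like the sketch below; the helper name and shapes are assumptions, not part of the original code.

import torch
import torch.distributions as D

# Hypothetical helper: relax a Bernoulli gate c_{i,j} with mean `gate_prob`
# using Gumbel noise at temperature `tau` (binary-concrete / Gumbel-Softmax).
def sample_soft_gate(gate_prob: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    gumbel = D.Gumbel(0., 1.)
    g_on = gumbel.sample(gate_prob.shape)   # noise for the "edge on" logit
    g_off = gumbel.sample(gate_prob.shape)  # noise for the "edge off" logit
    logit_on = gate_prob.clamp_min(1e-6).log() + g_on
    logit_off = (1. - gate_prob).clamp_min(1e-6).log() + g_off
    # soft gate in (0, 1); approaches a hard 0/1 sample as tau -> 0
    return torch.sigmoid((logit_on - logit_off) / tau)

soft_c = sample_soft_gate(torch.full((3, 1, 1), 0.5), tau=1.0)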
Example #2
    def _train(self, BATCH):
        if self.is_continuous:
            action_target = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            if self.use_target_action_noise:
                action_target = self.target_noised_action(
                    action_target)  # [T, B, A]
        else:
            target_logits = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi,
                                      self.a_dim).float()  # [T, B, A]
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = self.critic.t(BATCH.obs_,
                                 action_target,
                                 begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = dc_r - q  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(q_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
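The discrete branch above uses a straight-through Gumbel-Softmax step. A self-contained sketch of that step follows; the function name and shapes are illustrative, not part of the snippet.

import torch
import torch.nn.functional as F
import torch.distributions as td

def straight_through_gumbel_softmax(logits, tau=1.0):
    # soft, differentiable sample from the relaxed categorical
    logp_all = logits.log_softmax(-1)
    gumbel_noise = td.Gumbel(0., 1.).sample(logp_all.shape)
    soft = ((logp_all + gumbel_noise) / tau).softmax(-1)
    # hard one-hot in the forward pass, gradients flow through `soft`
    hard = F.one_hot(soft.argmax(-1), logits.shape[-1]).float()
    return (hard - soft).detach() + soft

mu = straight_through_gumbel_softmax(torch.randn(4, 2, 6), tau=0.5)  # [T, B, A]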
Example #3
    def sample(self, sample_shape=torch.Size()):    
        if sample_shape is not None:
            sample_shape = torch.Size(sample_shape)

        # In comments, I use S as an indication of dimension(s) related to sample_shape
        # and B as an indication of dimension(s) related to batch_shape
        with torch.no_grad():
            batch_shape, K = self.batch_shape, self._K  
            # This will store the sequence of labels from k=0 to K+1
            # [S, B, K+2]
            L = torch.zeros(sample_shape + batch_shape + (K+2,), device=self._scores.device).long()            
            # [S, B, K+2, 3]
            eps = td.Gumbel(
                    loc=torch.zeros(L.shape + (3,), device=self._scores.device), 
                    scale=torch.ones(L.shape + (3,), device=self._scores.device)
            ).sample()
            # [...,K+1,3,3]
            W = self._arc_weight
            # [...,K+2,3]
            V = self._state_value
            for k in torch.arange(K+1, device=self._scores.device): 
                # weights of arcs leaving this coordinate
                # [B, 3, 3]
                W_k = W[...,k,:,:]
                # reshape to introduce sample_shape dimensions
                # [S, B, 3, 3]
                W_k = W_k.view((1,) * len(sample_shape) + W_k.shape).expand(sample_shape + (-1,)*len(W_k.shape))
                # origin state for coordinate k
                # [S, B]
                L_k = L[...,k]
                # reshape to a 3-dimensional one-hot encoding of the label 
                # [S, B, 3, 1]
                L_k = torch.nn.functional.one_hot(L_k, 3).unsqueeze(-1)
                # select the weights for destination (zeroing out the rest)
                # [S, B, 3, 3]
                logits_k = torch.where(L_k == 1, W_k, torch.zeros_like(W_k))
                # sum 0s out and incorporate value of destination
                # [S, B, 3]
                logits_k = logits_k.sum(-2) + V[...,k+1,:] 

                # Categorical sampling via Gumbel-Argmax
                #  possibly more efficient than td.Categorical(logits=logits_k).sample().long()
                L[...,k+1] = torch.argmax(logits_k + eps[...,k+1,:], -1).long()                

            assert (L[...,-1] == 1).all(), "Not every sample reached the final state"
            L = L[...,1:-1]  # discard the initial (k=0) and final (k=K+1) states
            # map to boolean and then float (torch represents discrete samples as float)
            return (L==2).float()
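The comment above notes that Gumbel-Argmax may be cheaper than td.Categorical(...).sample(). A small standalone check of that equivalence (shapes and sample count are illustrative):

import torch
import torch.distributions as td

logits = torch.randn(5)
n = 100_000
noise = td.Gumbel(0., 1.).sample((n,) + logits.shape)
gumbel_max = torch.argmax(logits + noise, dim=-1)          # Gumbel-Max draws
categorical = td.Categorical(logits=logits).sample((n,))   # reference draws
# both empirical frequencies should approach softmax(logits)
print(torch.bincount(gumbel_max, minlength=5) / n)
print(torch.bincount(categorical, minlength=5) / n)
print(logits.softmax(-1))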
Example #4
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_actions[aid] = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                target_logits = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(
                    target_pi, self.a_dims[aid]).float()  # [T, B, A]
                target_actions[aid] = action_target  # [T, B, A]
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs, q_targets = {}, {}
        for mid in self.model_ids:
            qs[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_targets[mid],
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error = dc_r - qs[mid]  # [T, B, 1]
            td_errors += td_error
            q_loss[aid] = 0.5 * td_error.square().mean()  # 1
            summaries[aid].update({
                'Statistics/q_min': qs[mid].min(),
                'Statistics/q_mean': qs[mid].mean(),
                'Statistics/q_max': qs[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                mu = _pi_diff + _pi  # [T, B, A]

            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = mu
            q_actor = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
            actor_loss[aid] = -q_actor.mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })
        return td_errors / self.n_agents_percopy, summaries
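For the actor update above, only the updated agent's replayed action is swapped for its differentiable policy output before re-querying the centralized critic. A toy illustration of that substitution, with invented agent ids and shapes:

import torch as th

agent_ids = ['a0', 'a1', 'a2']
replayed = {aid: th.randn(5, 2, 3) for aid in agent_ids}  # [T, B, A] per agent
mu = th.randn(5, 2, 3, requires_grad=True)                # current agent's differentiable action

aid = 'a1'
all_actions = {i: replayed[i] for i in agent_ids}
all_actions[aid] = mu
joint_action = th.cat(list(all_actions.values()), -1)     # [T, B, N*A]
# only `mu` is connected to the graph, so the critic's gradient reaches a1's actor alone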
Example #5
    def _train(self, BATCH):
        if self.is_continuous:
            target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(target_pi, dist.log_prob(
                target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index)  # [T, B, 1]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             (q_target - self.alpha * target_log_pi),
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                            self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
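squash_action and tsallis_entropy_log_q are external helpers not shown here. As a rough sketch, under the assumption that squash_action applies the usual tanh squashing with a log-density correction, it would look roughly like this:

import torch as th
import torch.distributions as td

def tanh_squash(raw_action, log_prob):
    action = th.tanh(raw_action)
    # change of variables: subtract log|d tanh(u)/du| summed over action dims
    correction = th.log(1. - action.pow(2) + 1e-6).sum(-1, keepdim=True)
    return action, log_prob - correction

dist = td.Independent(td.Normal(th.zeros(4, 2, 3), th.ones(4, 2, 3)), 1)
u = dist.rsample()                                         # [T, B, A]
a, logp = tanh_squash(u, dist.log_prob(u).unsqueeze(-1))   # [T, B, A], [T, B, 1]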
Example #6
    def _train(self, BATCH):
        for _ in range(self.delay_num):
            if self.is_continuous:
                action_target = self.target_noised_action(
                    self.actor.t(BATCH.obs_,
                                 begin_mask=BATCH.begin_mask))  # [T, B, A]
            else:
                target_logits = self.actor.t(
                    BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(target_pi,
                                          self.a_dim).float()  # [T, B, A]
            q1 = self.critic(BATCH.obs,
                             BATCH.action,
                             begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q2 = self.critic2(BATCH.obs,
                              BATCH.action,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q_target = th.minimum(
                self.critic.t(BATCH.obs_,
                              action_target,
                              begin_mask=BATCH.begin_mask),
                self.critic2.t(BATCH.obs_,
                               action_target,
                               begin_mask=BATCH.begin_mask))  # [T, B, 1]
            dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                 q_target,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
            td_error1 = q1 - dc_r  # [T, B, 1]
            td_error2 = q2 - dc_r  # [T, B, 1]

            q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
            q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
            critic_loss = 0.5 * (q1_loss + q2_loss)
            self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q1_actor = self.critic(BATCH.obs, mu,
                               begin_mask=BATCH.begin_mask)  # [T, B, 1]

        actor_loss = -q1_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)
        return (td_error1 + td_error2) / 2, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
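n_step_return is a project helper not shown here; assuming it generalizes the standard bootstrapped target, the 1-step analogue of the clipped double-Q target above is:

import torch as th

q1_t = th.randn(5, 2, 1)  # stand-in for self.critic.t(obs_, action_target)
q2_t = th.randn(5, 2, 1)  # stand-in for self.critic2.t(obs_, action_target)
reward = th.randn(5, 2, 1)
done = th.zeros(5, 2, 1)
gamma = 0.99

q_target = th.minimum(q1_t, q2_t)                          # [T, B, 1]
one_step_target = (reward + gamma * (1. - done) * q_target).detach()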
Example #7
    def _train_continuous(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        q1 = self.q_net(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.q_net2(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_pi = self.q_net2(BATCH.obs, pi,
                            begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        v_from_q_stop = (th.minimum(q1_pi, q2_pi) -
                         self.alpha * log_pi).detach()  # [T, B, 1]
        td_v = v - v_from_q_stop  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1

        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha *
                           (log_pi.detach() + self.target_entropy)).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
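The auto_adaption branch above tunes the entropy temperature. A minimal standalone version of that update, with log_pi and target_entropy as stand-ins and a plain Adam optimizer in place of alpha_oplr:

import torch as th

log_alpha = th.zeros(1, requires_grad=True)
optimizer = th.optim.Adam([log_alpha], lr=3e-4)

log_pi = th.randn(5, 2, 1)   # [T, B, 1], pretend policy log-probs
target_entropy = -3.0        # e.g. -|A| for a 3-dimensional action space

alpha = log_alpha.exp()
alpha_loss = -(alpha * (log_pi.detach() + target_entropy)).mean()  # 1
optimizer.zero_grad()
alpha_loss.backward()
optimizer.step()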
Example #8
    def __init__(self, u=0, b=1, t=.1, dim=-1):
        self.gumbel = distributions.Gumbel(loc=u, scale=b)
        self.temperature = t
        self.dim = dim
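The constructor above only stores the distribution, temperature, and axis. A hypothetical completion, assuming the class is meant to draw Gumbel-Softmax (concrete) samples along dim at temperature t; the class and method names are invented:

import torch
from torch import distributions

class GumbelSoftmaxSampler:
    def __init__(self, u=0, b=1, t=.1, dim=-1):
        self.gumbel = distributions.Gumbel(loc=u, scale=b)
        self.temperature = t
        self.dim = dim

    def sample(self, logits):
        # relaxed categorical sample; approaches one-hot as temperature -> 0
        noise = self.gumbel.sample(logits.shape)
        return ((logits + noise) / self.temperature).softmax(self.dim)

sampler = GumbelSoftmaxSampler(t=0.5)
probs = sampler.sample(torch.randn(2, 6))  # each row sums to 1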
Example #9
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        target_log_pis = 1.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_mu, target_log_std = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(
                    td.Normal(target_mu, target_log_std.exp()), 1)
                target_pi = dist.sample()  # [T, B, A]
                target_pi, target_log_pi = squash_action(
                    target_pi,
                    dist.log_prob(target_pi).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                target_logits = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(
                    -1)  # [T, B, 1]
                target_pi = F.one_hot(target_pi,
                                      self.a_dims[aid]).float()  # [T, B, A]
            target_actions[aid] = target_pi
            target_log_pis *= target_log_pi

        target_log_pis += th.finfo().eps
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {}
        for mid in self.model_ids:
            qs1[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            qs2[mid] = self.critics2[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets1[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]
            q_targets2[mid] = self.critics2[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_target = th.minimum(q_targets1[mid],
                                  q_targets2[mid])  # [T, B, 1]
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_target - self.alpha * target_log_pis,
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error1 = qs1[mid] - dc_r  # [T, B, 1]
            td_error2 = qs2[mid] - dc_r  # [T, B, 1]
            td_errors += (td_error1 + td_error2) / 2
            q1_loss = td_error1.square().mean()  # 1
            q2_loss = td_error2.square().mean()  # 1
            q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss
            summaries[aid].update({
                'Statistics/q_min': qs1[mid].min(),
                'Statistics/q_mean': qs1[mid].mean(),
                'Statistics/q_max': qs1[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        log_pi_actions = {}
        log_pis = {}
        sample_pis = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu, log_std = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
                pi = dist.rsample()  # [T, B, A]
                pi, log_pi = squash_action(
                    pi,
                    dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
                pi_action = BATCH_DICT[aid].action.arctanh()
                _, log_pi_action = squash_action(
                    pi_action,
                    dist.log_prob(pi_action).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                pi = _pi_diff + _pi  # [T, B, A]
                log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
                log_pi_action = (logp_all * BATCH_DICT[aid].action).sum(
                    -1, keepdim=True)  # [T, B, 1]
            log_pi_actions[aid] = log_pi_action
            log_pis[aid] = log_pi
            sample_pis[aid] = pi

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = sample_pis[aid]
            all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids}
            all_log_pis[aid] = log_pis[aid]

            q_s_pi = th.minimum(
                self.critics[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask),
                self.critics2[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask))  # [T, B, 1]

            _log_pis = 1.
            for _log_pi in all_log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps
            actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })

        if self.auto_adaption:
            _log_pis = 1.
            for _log_pi in log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps

            alpha_loss = -(
                self.alpha *
                (_log_pis + self.target_entropy).detach()).mean()  # 1

            self.alpha_oplr.optimize(alpha_loss)
            summaries['model'].update({
                'LOSS/alpha_loss':
                alpha_loss,
                'LEARNING_RATE/alpha_lr':
                self.alpha_oplr.lr
            })
        return td_errors / self.n_agents_percopy, summaries
Example #10
    def _train(self, BATCH):

        obs = get_first_vector(BATCH.obs)  # [T, B, S]
        obs_ = get_first_vector(BATCH.obs_)  # [T, B, S]
        _timestep = obs.shape[0]
        _batchsize = obs.shape[1]
        predicted_obs_ = self._forward_dynamic_model(obs,
                                                     BATCH.action)  # [T, B, S]
        predicted_reward = self._reward_model(obs, BATCH.action)  # [T, B, 1]
        predicted_done_dist = self._done_model(obs, BATCH.action)  # [T, B, 1]
        _obs_loss = F.mse_loss(obs_, predicted_obs_)  # todo
        _reward_loss = F.mse_loss(BATCH.reward, predicted_reward)
        _done_loss = -predicted_done_dist.log_prob(BATCH.done).mean()
        wm_loss = _obs_loss + _reward_loss + _done_loss
        self._wm_oplr.optimize(wm_loss)

        obs = th.reshape(obs, (_timestep * _batchsize, -1))  # [T*B, S]
        obs_ = th.reshape(obs_, (_timestep * _batchsize, -1))  # [T*B, S]
        actions = th.reshape(BATCH.action,
                             (_timestep * _batchsize, -1))  # [T*B, A]
        rewards = th.reshape(BATCH.reward,
                             (_timestep * _batchsize, -1))  # [T*B, 1]
        dones = th.reshape(BATCH.done,
                           (_timestep * _batchsize, -1))  # [T*B, 1]

        rollout_rewards = [rewards]
        rollout_dones = [dones]

        r_obs_ = obs_
        _r_obs = deepcopy(BATCH.obs_)
        r_done = (1. - dones)

        for _ in range(self._roll_out_horizon):
            r_obs = r_obs_
            _r_obs.vector.vector_0 = r_obs
            if self.is_continuous:
                r_action = self.actor.t(_r_obs)  # [T*B, A]
                if self.use_target_action_noise:
                    r_action = self.target_noised_action(r_action)  # [T*B, A]
            else:
                target_logits = self.actor.t(_r_obs)  # [T*B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T*B,]
                r_action = F.one_hot(target_pi, self.a_dim).float()  # [T*B, A]
            r_obs_ = self._forward_dynamic_model(r_obs, r_action)  # [T*B, S]
            r_reward = self._reward_model(r_obs, r_action)  # [T*B, 1]
            r_done = r_done * (1. - self._done_model(r_obs, r_action).sample()
                               )  # [T*B, 1]

            rollout_rewards.append(r_reward)  # [H+1, T*B, 1]
            rollout_dones.append(r_done)  # [H+1, T*B, 1]

        _r_obs.vector.vector_0 = obs
        q = self.critic(_r_obs, actions)  # [T*B, 1]
        _r_obs.vector.vector_0 = r_obs_
        q_target = self.critic.t(_r_obs, r_action)  # [T*B, 1]
        dc_r = rewards.clone()  # clone to avoid mutating `rewards` in place below
        for t in range(1, self._roll_out_horizon):
            dc_r += (self.gamma**t) * (rollout_rewards[t] * rollout_dones[t])
        dc_r += (self.gamma**self._roll_out_horizon) * rollout_dones[
            self._roll_out_horizon] * q_target  # [T*B, 1]

        td_error = dc_r - q  # [T*B, 1]
        q_loss = td_error.square().mean()  # 1
        self.critic_oplr.optimize(q_loss)

        # train actor
        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return th.ones_like(BATCH.reward), {
            'LEARNING_RATE/wm_lr': self._wm_oplr.lr,
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/wm_loss': wm_loss,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
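The discounted return accumulated over the imagined rollout above can be summarized in a few lines; the shapes and horizon below are illustrative stand-ins:

import torch as th

gamma, horizon = 0.99, 3
rollout_rewards = [th.randn(8, 1) for _ in range(horizon + 1)]  # [H+1][T*B, 1]
rollout_dones = [th.ones(8, 1) for _ in range(horizon + 1)]     # 1. = not yet done
q_target = th.randn(8, 1)                                       # target critic at last imagined state

dc_r = rollout_rewards[0].clone()
for t in range(1, horizon):
    dc_r += (gamma ** t) * rollout_rewards[t] * rollout_dones[t]
dc_r += (gamma ** horizon) * rollout_dones[horizon] * q_target  # bootstrap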