예제 #1
0
    def forward(self, flat_obs, actions=None):
        """Compute Q-values from observations plus a one-hot encoding of tau.

        :param flat_obs: Flat observation batch; ``split_tau`` separates it
            into the observation part and the tau part.
        :param actions: Optional action batch. When given, it is appended
            after the tau encoding in the network input.
        :return: ``-|last_fc(h)|``, so every output is non-positive.
        """
        obs, taus = split_tau(flat_obs)
        batch_size = obs.size()[0]

        # Zero tensor with a 1 scattered at each row's tau index, i.e. a
        # one-hot encoding over max_tau + 1 buckets (assuming one tau per
        # row). Negative taus are clamped into bucket 0.
        # NOTE(review): taus larger than max_tau are not clamped and would
        # make scatter_ index out of range — confirm callers bound tau.
        y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
        y_binary.zero_()
        t = torch.clamp(taus.data.long(), min=0)
        y_binary.scatter_(1, t, 1)

        # Input layout: [obs, one-hot tau, (actions)] along the feature dim.
        inputs = [obs, ptu.Variable(y_binary)]
        if actions is not None:
            inputs.append(actions)
        h = torch.cat(inputs, dim=1)

        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        return - torch.abs(self.last_fc(h))
예제 #2
0
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """Append a one-hot tau encoding to the observation and delegate.

        ``obs`` is split by ``split_tau`` into the raw observation and tau.
        Tau is written into a zero tensor of width ``max_tau + 1`` by
        scattering a 1 at its (non-negative-clamped) index. The result is
        concatenated to the raw observation along dim 1 and passed to the
        parent ``forward`` together with all the sampling flags.
        """
        raw_obs, taus = split_tau(obs)
        n_rows = raw_obs.size()[0]
        one_hot = ptu.FloatTensor(n_rows, self.max_tau + 1)
        one_hot.zero_()
        tau_indices = torch.clamp(taus.data.long(), min=0)
        one_hot.scatter_(1, tau_indices, 1)
        augmented = torch.cat((raw_obs, ptu.Variable(one_hot)), dim=1)
        return super().forward(
            obs=augmented,
            deterministic=deterministic,
            return_log_prob=return_log_prob,
            return_entropy=return_entropy,
            return_log_prob_of_mean=return_log_prob_of_mean,
        )
예제 #3
0
    def forward(self, flat_obs, actions=None):
        """Compute Q-values from observations plus a binary tau encoding.

        :param flat_obs: Flat observation batch; ``split_tau`` separates it
            into the observation part and the tau part.
        :param actions: Optional action batch. When given, it is appended
            after the tau encoding in the network input.
        :return: ``-|last_fc(h)|``, so every output is non-positive.
        """
        obs, taus = split_tau(flat_obs)
        batch_size = taus.size()[0]
        # NOTE(review): make_binary_tensor is defined elsewhere; presumably it
        # encodes each tau into len(self.max_tau) binary features — confirm.
        y_binary = make_binary_tensor(taus, len(self.max_tau), batch_size)

        # Input layout: [obs, tau encoding, (actions)] along the feature dim.
        inputs = [obs, ptu.Variable(y_binary)]
        if actions is not None:
            inputs.append(actions)
        h = torch.cat(inputs, dim=1)

        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        return - torch.abs(self.last_fc(h))
예제 #4
0
 def forward(self, flat_obs, **kwargs):
     """Optionally normalize the observation part, then call ``self.qf``.

     ``flat_obs`` is split by ``split_tau``. Only the observation part is
     normalized (when ``self.obs_normalizer`` is set); tau passes through
     unchanged. The two parts are re-joined along the feature axis and
     forwarded with any extra keyword arguments.
     """
     observations, taus = split_tau(flat_obs)
     if self.obs_normalizer:
         observations = self.obs_normalizer.normalize(observations)
     # Re-join along the last (feature) axis. The previous default dim=0
     # stacked observations and taus along the batch axis instead.
     flat_obs = torch.cat((observations, taus), dim=-1)
     return self.qf.forward(flat_obs, **kwargs)
예제 #5
0
    def forward(self, flat_obs, actions=None):
        """Q-value from observation (+ actions) and a broadcast tau vector.

        The parent ``forward`` receives the observation (concatenated with
        actions along dim 1 when given) and a ``tau_vector_len``-wide tensor
        built by adding tau to a zero tensor. The negated absolute value of
        the parent's output is returned.
        """
        obs, taus = split_tau(flat_obs)
        h = obs if actions is None else torch.cat((obs, actions), dim=1)
        # NOTE(review): torch.zeros allocates on the default device; if taus
        # lives on a GPU this addition may fail — confirm device handling.
        zeros = torch.zeros((h.size()[0], self.tau_vector_len))
        tau_vector = Variable(zeros + taus.data)
        q_raw = super().forward(h, tau_vector)
        return - torch.abs(q_raw)
예제 #6
0
 def forward(self, flat_obs, actions=None, **kwargs):
     """Optionally normalize observations/actions, then call ``self.qf``.

     ``flat_obs`` is split by ``split_tau``. Observations are normalized when
     ``self.obs_normalizer`` is set; actions are normalized when
     ``self.action_normalizer`` is set and actions were given. Tau passes
     through unchanged and is re-joined to the observations along the
     feature axis.
     """
     observations, taus = split_tau(flat_obs)
     if self.obs_normalizer:
         observations = self.obs_normalizer.normalize(observations)
     if self.action_normalizer and actions is not None:
         actions = self.action_normalizer.normalize(actions)
     # Re-join along the last (feature) axis. The previous default dim=0
     # stacked observations and taus along the batch axis instead.
     flat_obs = torch.cat((observations, taus), dim=-1)
     return self.qf.forward(flat_obs, actions=actions, **kwargs)
예제 #7
0
    def forward(
            self,
            flat_obs,
            return_preactivations=False,
    ):
        """Append a binary tau encoding to the observation and delegate.

        The encoding comes from ``make_binary_tensor`` (defined elsewhere)
        with width ``len(self.max_tau)``. It is concatenated to the
        observation along dim 1 before calling the parent ``forward``.
        """
        obs, taus = split_tau(flat_obs)
        encoded_tau = make_binary_tensor(
            taus, len(self.max_tau), taus.size()[0])
        features = torch.cat((obs, ptu.Variable(encoded_tau)), dim=1)
        return super().forward(
            features,
            return_preactivations=return_preactivations,
        )
예제 #8
0
    def forward(
            self,
            flat_obs,
            return_preactivations=False,
    ):
        """Append a broadcast tau vector to the observation and delegate.

        Tau is added to a ``(batch, tau_vector_len)`` zero tensor, and the
        result is concatenated to the observation along dim 1 before calling
        the parent ``forward``.
        """
        obs, taus = split_tau(flat_obs)
        rows = obs.size()[0]
        tau_block = torch.zeros((rows, self.tau_vector_len)) + taus.data
        features = torch.cat((obs, ptu.Variable(tau_block)), dim=1)
        return super().forward(
            features,
            return_preactivations=return_preactivations,
        )
예제 #9
0
 def forward(
         self,
         flat_obs,
         return_preactivations=False,
 ):
     """Two-tower deterministic policy over observation and tau.

     The observation goes through ``first_input`` and a broadcast tau vector
     (width ``tau_vector_len``) goes through ``second_input``. Both are
     activated, concatenated along dim 1, and run through ``self.fcs``.
     ``output_activation(last_fc(h))`` is returned, together with the
     preactivations when ``return_preactivations`` is True.
     """
     obs, taus = split_tau(flat_obs)
     tau_vector = Variable(
         torch.zeros((obs.size()[0], self.tau_vector_len)) + taus.data)
     obs_features = self.hidden_activation(self.first_input(obs))
     tau_features = self.hidden_activation(self.second_input(tau_vector))
     h = torch.cat((obs_features, tau_features), dim=1)
     for layer in self.fcs:
         h = self.hidden_activation(layer(h))
     preactivations = self.last_fc(h)
     actions = self.output_activation(preactivations)
     if not return_preactivations:
         return actions
     return actions, preactivations
예제 #10
0
    def forward(self,
                obs,
                deterministic=False,
                return_log_prob=False,
                return_entropy=False,
                return_log_prob_of_mean=False):
        """Append a broadcast tau vector to the observation and delegate.

        Tau is added to a ``(batch, tau_vector_len)`` zero tensor, and the
        result is concatenated to the raw observation along dim 1. The
        parent ``forward`` is called with that input and all sampling flags.
        """
        raw_obs, taus = split_tau(obs)
        tau_vector = (
            torch.zeros((raw_obs.size()[0], self.tau_vector_len)) + taus.data
        )
        augmented = torch.cat((raw_obs, ptu.Variable(tau_vector)), dim=1)
        flags = dict(
            deterministic=deterministic,
            return_log_prob=return_log_prob,
            return_entropy=return_entropy,
            return_log_prob_of_mean=return_log_prob_of_mean,
        )
        return super().forward(obs=augmented, **flags)
예제 #11
0
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """Append a binary tau encoding to the observation and delegate.

        The encoding comes from ``make_binary_tensor`` (defined elsewhere)
        with width ``len(self.max_tau)``. It is concatenated to the raw
        observation along dim 1, and the parent ``forward`` is called with
        all sampling flags.
        """
        raw_obs, taus = split_tau(obs)
        encoded = make_binary_tensor(taus, len(self.max_tau), taus.size()[0])
        augmented = torch.cat((raw_obs, ptu.Variable(encoded)), dim=1)
        flags = dict(
            deterministic=deterministic,
            return_log_prob=return_log_prob,
            return_entropy=return_entropy,
            return_log_prob_of_mean=return_log_prob_of_mean,
        )
        return super().forward(obs=augmented, **flags)
예제 #12
0
    def forward(
            self,
            flat_obs,
            return_preactivations=False
    ):
        """Append a one-hot tau encoding to the observation and delegate.

        Tau is written into a zero tensor of width ``max_tau + 1`` by
        scattering a 1 at its (non-negative-clamped) index. The result is
        concatenated to the observation along dim 1 before calling the
        parent ``forward``.
        """
        obs, taus = split_tau(flat_obs)
        n_rows = obs.size()[0]
        one_hot = ptu.FloatTensor(n_rows, self.max_tau + 1)
        one_hot.zero_()
        one_hot.scatter_(1, torch.clamp(taus.data.long(), min=0), 1)
        features = torch.cat((obs, ptu.Variable(one_hot)), dim=1)
        return super().forward(
            features,
            return_preactivations=return_preactivations,
        )
예제 #13
0
    def forward(self, flat_obs, actions=None):
        """Compute Q-values from observations plus a broadcast tau vector.

        :param flat_obs: Flat observation batch; ``split_tau`` separates it
            into the observation part and the tau part.
        :param actions: Optional action batch. When given, it is appended
            after the tau vector in the network input.
        :return: ``-|last_fc(h)|``, so every output is non-positive.
        """
        obs, taus = split_tau(flat_obs)
        batch_size = obs.size()[0]
        # Tau added to a (batch, tau_vector_len) zero tensor.
        # NOTE(review): torch.zeros allocates on the default device; if taus
        # lives on a GPU this addition may fail — confirm device handling.
        tau_vector = torch.zeros((batch_size, self.tau_vector_len)) + taus.data

        # Input layout: [obs, tau vector, (actions)] along the feature dim.
        # (The previous version also ran a dead torch.cat that referenced the
        # undefined name `action`, raising NameError whenever actions were
        # passed.)
        inputs = [obs, ptu.Variable(tau_vector)]
        if actions is not None:
            inputs.append(actions)
        h = torch.cat(inputs, dim=1)

        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        return - torch.abs(self.last_fc(h))
예제 #14
0
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """Sample an action from a tanh-squashed Gaussian conditioned on tau.

        ``obs`` is split by ``split_tau``. The raw observation and a tau
        vector (tau added to a ``(batch, tau_vector_len)`` zero tensor) pass
        through separate input layers, ``first_input`` and ``second_input``.
        Their activations are concatenated and run through ``self.fcs``.

        :param obs: Flat observation batch including tau.
        :param deterministic: If True, return ``tanh(mean)`` instead of
            sampling.
        :param return_log_prob: If True (and not deterministic), sample while
            keeping the pre-tanh value, and return the sample's log
            probability summed over action dimensions.
        :param return_entropy: If True, return the entropy of the pre-tanh
            Gaussian, summed over action dimensions.
        :param return_log_prob_of_mean: If True, return the log probability
            of ``tanh(mean)`` under the distribution, summed over action
            dimensions.
        :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
            mean_action_log_prob, pre_tanh_value)``. ``log_prob``,
            ``entropy``, ``mean_action_log_prob`` and ``pre_tanh_value`` are
            None when they were not computed.
        """
        obs, taus = split_tau(obs)
        batch_size = obs.size()[0]
        # NOTE(review): torch.zeros allocates on the default device; if taus
        # lives on a GPU this addition may fail — confirm device handling.
        tau_vector = Variable(
            torch.zeros((batch_size, self.tau_vector_len)) + taus.data)
        h = obs
        h1 = self.hidden_activation(self.first_input(h))
        h2 = self.hidden_activation(self.second_input(tau_vector))
        h = torch.cat((h1, h2), dim=1)
        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            # Learned, state-dependent log std, clamped for numerical safety.
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            # Fixed std configured on the module.
            std = self.std
            log_std = self.log_std

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = TanhNormal(mean, std)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
                log_prob = log_prob.sum(dim=1, keepdim=True)
            else:
                action = tanh_normal.sample()

        if return_entropy:
            # Per-dimension Gaussian entropy: log(std) + 0.5 * (1 + log(2*pi)).
            # NOTE: this is the entropy of the pre-tanh Gaussian X. tanh is
            # not volume-preserving, so H(tanh X) = H(X) + E[log(1 - tanh(X)^2)]
            # <= H(X); this value is an upper bound on the squashed entropy.
            entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
            entropy = entropy.sum(dim=1, keepdim=True)
        if return_log_prob_of_mean:
            tanh_normal = TanhNormal(mean, std)
            mean_action_log_prob = tanh_normal.log_prob(
                torch.tanh(mean),
                pre_tanh_value=mean,
            )
            mean_action_log_prob = mean_action_log_prob.sum(dim=1,
                                                            keepdim=True)
        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )