def forward(self, x, t, mask):
        """Negative log-likelihood of inter-event gaps under an exponential model.

        The rate at each step is a weighted sum (weights ``self.alpha``) of the
        previous ``self.m`` gaps plus a positive base rate ``exp(self.gamma)``
        -- a discrete, Hawkes-like self-exciting intensity.

        Args:
            x: sequence tensor of shape (T, BS); only its size is used here.
            t: timing tensor of shape (T, BS, 2); channel 0 is interpreted
               here as the inter-event gap d_j at each step.
            mask: (T, BS) 0/1 validity mask (used for the metrics only).

        Returns:
            Tuple ``(-ll_loss, metric_dict)``: the NLL summed from step 1
            onward, and a dict of detached diagnostics.
        """
        #t : TxBSx2 ; channel 0 holds the per-step gap
        d_j = t[:, :, 0]  #TxBS
        # NOTE: seq_length is computed but unused below.
        batch_size, seq_length = x.size(1), x.size(0)
        past_influences = []
        for idx in range(self.m):
            # Shift d_j down by (idx+1) rows with zero padding, so row j holds
            # the gap observed (idx+1) events earlier.
            t_pad = torch.cat(
                [torch.zeros(idx + 1, batch_size).to(device),
                 d_j])[:-(idx + 1), :][:, :, None]  #TxBSx1
            past_influences.append(t_pad)
        # Weight each lag by its learned coefficient alpha[idx].
        past_influences = torch.cat(
            past_influences, dim=-1) * self.alpha[None, None, :]  #TxBSxm
        # Rate = sum of lagged influences + base rate exp(gamma) (> 0).
        total_influence = torch.sum(past_influences,
                                    dim=-1) + self.gamma[None, :].exp()
        #To consider from time step 1
        m = Exponential(total_influence[1:, :])  #T-1xBS
        # NOTE(review): the log-likelihood is NOT masked here, while the
        # metrics below are -- confirm padded steps should count in the loss.
        ll_loss = (m.log_prob(d_j[1:, :])).sum()

        metric_dict = {
            'true_ll': -ll_loss.detach(),
            "marker_acc": 0.,
            "marker_acc_count": 1.
        }
        with torch.no_grad():
            # Point prediction = mean of Exponential(rate) = 1 / rate.
            time_mse = ((d_j[1:, :] - 1. / total_influence[1:, :]) *
                        mask[1:, :])**2.
            metric_dict['time_mse'] = time_mse.sum().detach().cpu().numpy()
            metric_dict['time_mse_count'] = mask[
                1:, :].sum().detach().cpu().numpy()
        return -ll_loss, metric_dict
class ExpTimeToOpen(Base):
    """Exponential model of the delay stored in column 0 of the input."""

    def __init__(self, rate):
        # Wrap a standard exponential distribution with the given rate.
        self.dist = Exponential(rate)

    def log_prob(self, times):
        """Return the log-density of ``times[:, 0]`` under the model."""
        opening_delay = times[:, 0]
        return self.dist.log_prob(opening_delay)
# --- Example no. 3 (scrape-artifact separator "Ejemplo n.º 3" / vote count "0";
# commented out so the file parses) ---
class ExpTimeToOpen(Base):
    """Exponential model of the gap between two clock readings.

    ``times[:, 0]`` and ``times[:, 1]`` look like start/open times on a
    24-unit clock (presumably hours of day -- the +24 wrap suggests so;
    confirm against the caller): when the gap is negative, the open event
    happened on the next day and 24 is added to recover the true delay.
    """

    def __init__(self, rate):
        # Wrap a standard exponential distribution with the given rate.
        self.dist = Exponential(rate)

    def log_prob(self, times):
        """Return the log-density of the wrap-corrected open delay."""
        dt = times[:, 1] - times[:, 0]
        # Vectorized wrap-around. The original used Tensor.apply_ with a
        # Python lambda, which is CPU-only, breaks autograd, and iterates
        # element-by-element; torch.where is equivalent and runs anywhere.
        dt = torch.where(dt < 0, dt + 24, dt)
        return self.dist.log_prob(dt)
# --- Example no. 4 (scrape-artifact separator "Ejemplo n.º 4" / vote count "0";
# commented out so the file parses) ---
class RightTruncatedExponential(torch.distributions.Distribution):
    """Exponential(rate) distribution truncated to the interval (0, upper).

    Density: p(x) = rate * exp(-rate * x) / Z for 0 < x < upper, where
    Z = base.cdf(upper) = 1 - exp(-rate * upper) renormalizes the mass.
    """

    def __init__(self, rate, upper):
        self.base = Exponential(rate)
        self._batch_shape = self.base.rate.size()
        self._upper = upper
        self.upper = torch.full_like(self.base.rate, upper)
        # normaliser Z = P_base(X <= upper)
        self.normaliser = self.base.cdf(self.upper)
        # Inverse-transform sampling: u ~ U(0, Z), then base.icdf(u) lands
        # in (0, upper) with the truncated density.
        self.uniform = Uniform(torch.zeros_like(self.upper), self.normaliser)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sample via the base distribution's inverse CDF."""
        u = self.uniform.rsample(sample_shape)
        x = self.base.icdf(u)
        return x

    def log_prob(self, value):
        """Base log-density shifted by the log-normaliser."""
        return self.base.log_prob(value) - torch.log(self.normaliser)

    def cdf(self, value):
        return self.base.cdf(value) / self.normaliser

    def icdf(self, value):
        return self.base.icdf(value * self.normaliser)

    def cross_entropy(self, other):
        """Closed-form H(self, other) = -E_self[log other.prob(X)].

        Requires matching truncation points. Derivation:
        H(p, q) = -log(r_q) + log(Z_q) + r_q * E_p[X], with
        E_p[X] = (1 - exp(-r_p U) * (1 + r_p U)) / (r_p Z_p).
        """
        assert isinstance(other, RightTruncatedExponential)
        assert type(self.base) is type(
            other.base) and self._upper == other._upper
        a = torch.log(other.base.rate) - torch.log(other.normaliser)
        log_b = torch.log(self.base.rate) + torch.log(
            other.base.rate) - torch.log(self.normaliser)
        b = torch.exp(log_b)
        # First-moment term of the truncated exponential. BUG FIX: the
        # original wrote exp(-rate) here, silently hard-coding upper == 1;
        # generalized to any truncation point (identical when upper == 1).
        ru = self.base.rate * self.upper
        c = (torch.exp(-ru) * (-ru - 1) + 1) / (self.base.rate**2)
        return -(a - b * c)

    def entropy(self):
        """Differential entropy, i.e. the cross-entropy with itself."""
        return self.cross_entropy(self)