def forward(self, x, t, mask):
    """Return the masked negative log-likelihood of inter-event times and metrics.

    The rate at step ``i`` is
    ``exp(gamma) + sum_{k=1..m} alpha[k-1] * d[i-k]``, where ``d = t[:, :, 0]``
    and ``d[i-k]`` is taken as 0 before the start of the sequence. Each
    ``d[i]`` for ``i >= 1`` is scored under ``Exponential(rate_i)``.

    Args:
        x: not used by this computation; kept for interface compatibility.
        t: tensor whose ``[:, :, 0]`` slice gives the times ``d`` indexed as
            ``(T, BS)`` (original comment: ``T x BS x 2``).
        mask: weights indexed like ``d``. Presumably 1 for real events and 0
            for padding (TODO confirm with callers). Applied to both the
            log-likelihood and the MSE metric.

    Returns:
        ``(nll, metric_dict)``. ``nll`` is the masked negative log-likelihood
        summed over steps ``1..T-1`` and the batch. ``metric_dict`` holds:
        ``true_ll`` (the same value, detached); constant marker-accuracy
        placeholders; ``time_mse``, the masked squared error between ``d`` and
        the predicted mean ``1 / rate``, as numpy scalars; and
        ``time_mse_count``.
    """
    d_j = t[:, :, 0]  # T x BS
    seq_len, batch_size = d_j.size(0), d_j.size(1)

    # Lagged copies d[i - lag] for lag = 1..m, zero-padded at the front and cut
    # back to length T. new_zeros matches d_j's device and dtype instead of
    # relying on a module-level `device` and the default dtype.
    lagged = [
        torch.cat([d_j.new_zeros(lag, batch_size), d_j])[:seq_len]
        for lag in range(1, self.m + 1)
    ]
    past_influences = torch.stack(lagged, dim=-1) * self.alpha[None, None, :]  # T x BS x m
    total_influence = past_influences.sum(dim=-1) + self.gamma[None, :].exp()  # T x BS

    # Step 0 has no history, so scoring starts from time step 1.
    valid = mask[1:, :]
    dist = Exponential(total_influence[1:, :])  # (T-1) x BS
    # Mask the likelihood the same way as the MSE metric below, so padded
    # positions do not contribute to the loss.
    ll_loss = (dist.log_prob(d_j[1:, :]) * valid).sum()

    metric_dict = {
        'true_ll': -ll_loss.detach(),
        "marker_acc": 0.,
        "marker_acc_count": 1.,
    }
    with torch.no_grad():
        # The predicted inter-event time is the exponential's mean, 1 / rate.
        time_mse = ((d_j[1:, :] - 1. / total_influence[1:, :]) * valid)**2.
        metric_dict['time_mse'] = time_mse.sum().detach().cpu().numpy()
        metric_dict['time_mse_count'] = valid.sum().detach().cpu().numpy()
    return -ll_loss, metric_dict
class ExpTimeToOpen(Base):
    """Exponential likelihood evaluated on the first column of ``times``.

    NOTE(review): another ``class ExpTimeToOpen(Base)`` is defined later in
    this file. If both live at module level, that later definition replaces
    this one and this class is never reachable by name.
    """

    def __init__(self, rate):
        # Underlying exponential distribution parameterised by `rate`.
        self.dist = Exponential(rate)

    def log_prob(self, times):
        """Return ``self.dist.log_prob(times[:, 0])``."""
        first_column = times[:, 0]
        return self.dist.log_prob(first_column)
class ExpTimeToOpen(Base):
    """Exponential likelihood of the wrapped gap between two time columns.

    The gap is ``times[:, 1] - times[:, 0]``; a negative gap is shifted by
    +24 (once). Presumably the columns are hours of day and the second time
    may fall on the following day — NOTE(review): confirm units with callers.
    """

    def __init__(self, rate):
        # Underlying exponential distribution parameterised by `rate`.
        self.dist = Exponential(rate)

    def log_prob(self, times):
        """Return the exponential log-density of the wrapped gap per row."""
        dt = times[:, 1] - times[:, 0]
        # Vectorized equivalent of the former per-element
        # `dt.apply_(lambda x: x + 24 if x < 0 else x)`. `Tensor.apply_` runs
        # a Python callback per element and only works on CPU tensors.
        dt = torch.where(dt < 0, dt + 24, dt)
        return self.dist.log_prob(dt)
class RightTruncatedExponential(torch.distributions.Distribution):
    """Exponential distribution truncated to the interval ``[0, upper]``.

    The density is the base ``Exponential(rate)`` density renormalised by
    ``Z = base.cdf(upper) = 1 - exp(-rate * upper)``::

        p(x) = rate * exp(-rate * x) / Z,    0 <= x <= upper

    ``log_prob``, ``cdf`` and ``icdf`` do not clamp their inputs. Values
    outside ``[0, upper]`` (or ``[0, 1]`` for ``icdf``) are assumed not to
    be passed in.

    Args:
        rate: rate tensor of the underlying exponential; its size is the
            batch shape.
        upper: scalar truncation point, shared by every batch element.
    """

    # rsample is reparameterised (uniform draw pushed through base.icdf).
    has_rsample = True

    def __init__(self, rate, upper):
        self.base = Exponential(rate)
        self._upper = upper
        self.upper = torch.full_like(self.base.rate, upper)
        # normaliser Z = P_base(X <= upper): the probability mass kept.
        self.normaliser = self.base.cdf(self.upper)
        # Uniform on (0, Z); mapping it through base.icdf lands in (0, upper).
        self.uniform = Uniform(torch.zeros_like(self.upper), self.normaliser)
        # Let Distribution set up its batch/event shape bookkeeping (the
        # event_shape property needs it). No arg_constraints are declared, so
        # argument validation is switched off.
        super().__init__(batch_shape=self.base.rate.size(), validate_args=False)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterised samples via inverse-cdf of the base exponential."""
        # u ~ U(0, Z) with Z = base.cdf(upper), so base.icdf(u) is in (0, upper).
        u = self.uniform.rsample(sample_shape)
        return self.base.icdf(u)

    def log_prob(self, value):
        """Return ``log(rate) - rate * value - log(Z)``."""
        return self.base.log_prob(value) - torch.log(self.normaliser)

    def cdf(self, value):
        """Return ``base.cdf(value) / Z``."""
        return self.base.cdf(value) / self.normaliser

    def icdf(self, value):
        """Return ``base.icdf(value * Z)`` for a probability ``value``."""
        return self.base.icdf(value * self.normaliser)

    def cross_entropy(self, other):
        """Return ``H(self, other) = -E_{x ~ self}[other.log_prob(x)]``.

        For rates ``lam`` (self) and ``mu`` (other), shared truncation point
        ``U``, and normalisers ``Z_lam``, ``Z_mu``::

            H = -(log mu - log Z_mu) + mu * E_self[x]
            E_self[x] = (1 - exp(-lam*U) * (1 + lam*U)) / (lam * Z_lam)

        Raises:
            AssertionError: if ``other`` is not a RightTruncatedExponential
                with the same ``upper``.
        """
        assert isinstance(other, RightTruncatedExponential)
        assert type(self.base) is type(
            other.base) and self._upper == other._upper
        lam = self.base.rate
        mu = other.base.rate
        # log of other's normalised density constant, log(mu / Z_mu).
        log_const = torch.log(mu) - torch.log(other.normaliser)
        # The previous formula used exp(-lam) * (lam + 1), i.e. it silently
        # assumed upper == 1. Use lam * upper so any truncation point works.
        lam_u = lam * self.upper
        mean = (1 - torch.exp(-lam_u) * (lam_u + 1)) / (lam * self.normaliser)
        return -(log_const - mu * mean)

    def entropy(self):
        """Return the differential entropy, i.e. ``cross_entropy(self)``."""
        return self.cross_entropy(self)