def nll_mtlr(phi, idx_durations, events, reduction='mean', epsilon=1e-7):
    """Negative log-likelihood for the MTLR parametrized model [1] [2].

    This is essentially a PMF parametrization with an extra cumulative sum,
    as explained in [3]. Here that means ``phi`` is first summed in reverse
    along ``dim=1`` (``utils.cumsum_reverse``), and the result is handed to
    ``nll_pmf`` together with all remaining arguments.

    Arguments:
        phi {torch.tensor} -- Estimates in (-inf, inf). The reverse
            cumulative sum of ``phi`` over ``dim=1`` is what ``nll_pmf``
            receives as its PMF parametrization.
            NOTE(review): presumably shape (batch, n_durations) with dim 1
            indexing the discrete time grid -- confirm against callers.
        idx_durations {torch.tensor} -- Event times represented as indices.
        events {torch.tensor} -- Indicator of event (1.) or censoring (0.).
            Same length as 'idx_durations'.
        reduction {string} -- How to reduce the loss.
            'none': No reduction.
            'mean': Mean of tensor.
            'sum': Sum of tensor.
        epsilon {float} -- Small constant forwarded unchanged to ``nll_pmf``.

    Returns:
        torch.tensor -- The negative log-likelihood.

    References:
    [1] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos.
        Learning patient-specific cancer survival distributions as a sequence
        of dependent regressors. In Advances in Neural Information Processing
        Systems 24, pages 1845–1853. Curran Associates, Inc., 2011.
        https://papers.nips.cc/paper/4210-learning-patient-specific-cancer-survival-distributions-as-a-sequence-of-dependent-regressors.pdf

    [2] Stephane Fotso. Deep neural networks for survival analysis based on a
        multi-task framework. arXiv preprint arXiv:1801.05512, 2018.
        https://arxiv.org/pdf/1801.05512.pdf

    [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival
        Prediction with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
        https://arxiv.org/pdf/1910.06724.pdf
    """
    # The MTLR step: reverse cumulative sum over dim 1. Everything after this
    # is identical to the plain PMF loss.
    phi = utils.cumsum_reverse(phi, dim=1)
    return nll_pmf(phi, idx_durations, events, reduction, epsilon)
def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
    """Predict the discrete probability mass function (PMF) for ``input``.

    Pipeline:
      1. obtain raw outputs via ``self.predict``;
      2. reverse cumulative sum along ``dim=1`` (the MTLR transform);
      3. pad an extra column with ``utils.pad_col``, softmax over ``dim=1``;
      4. drop that padded (last) column again.

    Because the softmax runs over one more column than is returned, each
    returned row sums to at most one.

    Arguments:
        input -- Model input, forwarded to ``self.predict``.
        batch_size {int} -- Forwarded to ``self.predict``.
        numpy -- Forwarded to ``utils.array_or_tensor`` to pick the output
            container. NOTE(review): presumably None means "match the type
            of ``input``" -- confirm in ``array_or_tensor``.
        eval_ {bool} -- Forwarded to ``self.predict``.
        to_cpu {bool} -- Forwarded to ``self.predict``.
        num_workers {int} -- Forwarded to ``self.predict``.

    Returns:
        The PMF as returned by ``utils.array_or_tensor``.
    """
    # Positional call kept verbatim. The two literal ``False`` values fill
    # the 3rd and 5th positional slots of ``self.predict``.
    # NOTE(review): presumably ``numpy`` and ``grads`` -- confirm signature.
    raw_out = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
    summed = utils.cumsum_reverse(raw_out, dim=1)
    probs = utils.pad_col(summed).softmax(1)
    # Discard the padded column; only the real time-grid columns are kept.
    return utils.array_or_tensor(probs[:, :-1], numpy, input)
def test_cumsum_reverse_dim_1():
    """cumsum_reverse(dim=1) must agree with a NumPy flip/cumsum/flip."""
    torch.manual_seed(1234)
    tensor = torch.randn(5, 16)
    # Reference implementation: reverse the columns, accumulate, reverse back.
    expected = np.flip(np.flip(tensor.numpy(), 1).cumsum(1), 1)
    actual = cumsum_reverse(tensor, dim=1).numpy()
    max_abs_err = np.max(np.abs(actual - expected))
    assert max_abs_err < 1e-6
def test_cumsum_reverse_error_dim():
    """cumsum_reverse must raise NotImplementedError for dim=0 and dim=2."""
    x = torch.randn((5, 3))
    for unsupported_dim in (0, 2):
        with pytest.raises(NotImplementedError):
            cumsum_reverse(x, dim=unsupported_dim)