Example #1
def apt_maf_loss_atomic_proposal(net, svi=False, combined_loss=False):
    """Define loss function for training with a atomic proposal. Assumes a
    uniform proposal distribution over each sample parameter and an externally
    provided set of alternatives.

    net: MAF-based conditional density net
    svi : bool
        Whether to use SVI version of the mdn or not
    """
    assert net.density == 'maf'
    assert not svi, 'SVI not supported for MAFs'

    # define symbolic variables to hold the params that will be inferred
    # params    : n_batch x n_outputs
    # theta_all : (n_batch * (n_atoms + 1)) x n_outputs
    # lprs      : (n_atoms + 1) x n_batch
    # stats     : n_batch x n_inputs
    # x_nl      : (n_batch * (n_atoms + 1)) x n_inputs
    theta_all = tensorN(2, name='params_nl', dtype=dtype)
    x_nl = tensorN(2, name='stats_nl', dtype=dtype)
    lprs = tensorN(2, name='lprs', dtype=dtype)  # log tilde_p / p

    n_batch = tt.shape(lprs)[1]
    n_atoms = tt.shape(lprs)[0] - 1

    # compute MAF log-densities for true and other atoms
    lprobs = theano.clone(output=net.lprobs,
                          replace={
                              net.params: theta_all,
                              net.stats: x_nl
                          },
                          share_inputs=True)
    lprobs = tt.reshape(lprobs, newshape=(n_atoms + 1, n_batch), ndim=2)

    # compute nonnormalized log posterior probabilities
    atomic_ppZ = lprobs - lprs
    # compute posterior probability of true params in atomic task
    atomic_pp = atomic_ppZ[0, :].squeeze() - \
        MyLogSumExp(atomic_ppZ, axis=0).squeeze()

    # collect the extra input variables that have to be provided for each
    # training data point, and calculate the loss by averaging over samples
    trn_inputs = [theta_all, x_nl, lprs]
    if combined_loss:  # add prior loss on prior samples
        l_ml = lprobs[0, :].squeeze()  # direct posterior evaluation
        is_prior_sample = tensorN(1, name='prop_mask', dtype=dtype)
        trn_inputs.append(is_prior_sample)
        loss = -tt.mean(atomic_pp + is_prior_sample * l_ml)
    else:
        loss = -tt.mean(atomic_pp)

    return loss, trn_inputs
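
# Illustrative sketch (plain numpy/scipy, not part of the library): the atomic
# normalization above is a log-softmax over the true parameter and its atoms, taken
# after subtracting the log proposal/prior ratios. Assuming lprobs and lprs both have
# shape (n_atoms + 1, n_batch) with the true sample in row 0:
import numpy as np
from scipy.special import logsumexp

def atomic_log_posterior(lprobs, lprs):
    atomic_ppZ = lprobs - lprs                             # unnormalized log posterior per atom
    return atomic_ppZ[0] - logsumexp(atomic_ppZ, axis=0)   # normalize over the atomic set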
Example #2
def snpeb_loss(model, svi=False):
    # note that lprobs and dlprobs are the same for non-svi networks
    lprobs = model.lprobs if svi else model.dlprobs
    iws = tensorN(1, name='iws', dtype=dtype)  # importance weights
    loss = -tt.mean(iws * lprobs)
    # collect extra input variables to be provided for each training data point
    trn_inputs = [model.params, model.stats, iws]
    return loss, trn_inputs
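
# Illustrative sketch (plain numpy, not part of the library): snpeb_loss above is an
# importance-weighted negative mean log-likelihood, with iws reweighting proposal
# samples towards the prior.
import numpy as np

def snpeb_loss_np(lprobs, iws):
    # lprobs: per-sample log q(theta | x); iws: importance weights, e.g. prior(theta) / proposal(theta)
    return -np.mean(iws * lprobs)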
Example #3
def test_batched_matrix_ops(dim=4, nsamples=100):
    A_pd = np.full((nsamples, dim, dim), np.nan, dtype=dtype)
    A_nonsing = np.full((nsamples, dim, dim), np.nan, dtype=dtype)
    L = np.full((nsamples, dim, dim), np.nan, dtype=dtype)
    inv = np.full((nsamples, dim, dim), np.nan, dtype=dtype)
    det = np.full(nsamples, np.nan, dtype=dtype)
    for i in range(nsamples):
        L[i] = np.tril(np.random.randn(dim, dim),
                       -1) + np.diag(1.0 + np.exp(np.random.randn(dim)))
        A_pd[i] = np.dot(L[i], L[i].T)
        L2 = np.tril(np.random.rand(dim, dim), -1) + np.diag(
            np.exp(np.random.randn(dim)))
        L3 = np.tril(np.random.rand(dim, dim), -1) + np.diag(
            np.exp(np.random.randn(dim)))
        A_nonsing[i] = np.dot(np.dot(L2, A_pd[i]), L3.T)
        inv[i] = np.linalg.inv(A_nonsing[i])
        det[i] = np.linalg.det(A_nonsing[i])

    tA = sym.tensorN(3)
    f_choleach = theano.function(inputs=[tA], outputs=sym.cholesky_each(tA))
    f_inveach = theano.function(inputs=[tA], outputs=sym.invert_each(tA))
    f_deteach = theano.function(inputs=[tA], outputs=sym.det_each(tA))

    symL = f_choleach(A_pd)
    symdet = f_deteach(A_nonsing)
    syminv = f_inveach(A_nonsing)

    assert np.allclose(symL, L, atol=1e-8)
    assert np.allclose(symdet, det, atol=1e-8)
    assert np.allclose(syminv, inv, atol=1e-8)

    try:
        # try Cholesky factorizing some non-symmetric matrices
        f_choleach(A_nonsing)
    except Exception as e:
        assert isinstance(e, np.linalg.LinAlgError), \
            "unexpected error when trying Cholesky factorization of a non-symmetric matrix"
Example #4
    def __init__(self, n_inputs=None, n_outputs=None, input_shape=None,
                 n_bypass=0,
                 density='mog',
                 n_hiddens=(10, 10), impute_missing=True, seed=None,
                 n_filters=(), filter_sizes=3, pool_sizes=2,
                 n_rnn=0,
                 **density_opts):

        """Initialize a mixture density network with custom layers

        Parameters
        ----------
        n_inputs : int
            Total input dimensionality (data/summary stats)
        n_outputs : int
            Dimensionality of output (simulator parameters)
        input_shape : tuple
            Size to which data are reshaped before CNN or RNN
        n_bypass : int
            Number of elements at end of input which bypass CNN or RNN
        density : string
            Type of conditional density estimated by the network, 'mog' or 'maf'
        n_components : int
            Number of components of the mixture density (passed via density_opts)
        n_filters : list of ints
            Number of filters per convolutional layer
        filter_sizes : int or list of ints
            Filter size per convolutional layer
        pool_sizes : int or list of ints
            Pooling size per convolutional layer
        n_hiddens : list of ints
            Number of hidden units per fully connected layer
        n_rnn : int
            Number of RNN units (0 disables the recurrent layer)
        impute_missing : bool
            If set to True, learns replacement value for NaNs, otherwise those
            inputs are set to zero
        seed : int or None
            If provided, random number generator will be seeded
        density_opts : dict
            Options for the density estimator
        """
        if n_rnn > 0 and len(n_filters) > 0:
            raise NotImplementedError
        assert isint(n_inputs) and isint(n_outputs)\
            and n_inputs > 0 and n_outputs > 0

        self.density = density.lower()
        self.impute_missing = impute_missing
        self.n_hiddens = list(n_hiddens)
        self.n_outputs, self.n_inputs = n_outputs, n_inputs
        self.n_bypass = n_bypass

        self.n_rnn = n_rnn

        self.n_filters, self.filter_sizes, self.pool_sizes, n_cnn = \
            list(n_filters), filter_sizes, pool_sizes, len(n_filters)
        if type(self.filter_sizes) is int:
            self.filter_sizes = [self.filter_sizes for _ in range(n_cnn)]
        else:
            assert len(self.filter_sizes) >= n_cnn
        if type(self.pool_sizes) is int:
            self.pool_sizes = [self.pool_sizes for _ in range(n_cnn)]
        else:
            assert len(self.pool_sizes) >= n_cnn

        self.iws = tt.vector('iws', dtype=dtype)

        self.seed = seed
        if seed is not None:
            self.rng = np.random.RandomState(seed=seed)
        else:
            self.rng = np.random.RandomState()
        lasagne.random.set_rng(self.rng)

        self.input_shape = (n_inputs,) if input_shape is None else input_shape
        assert np.prod(self.input_shape) + self.n_bypass == self.n_inputs
        assert 1 <= len(self.input_shape) <= 3

        # params: output placeholder (batch, self.n_outputs)
        self.params = tensorN(2, name='params', dtype=dtype)

        # stats : input placeholder, (batch, self.n_inputs)
        self.stats = tensorN(2, name='stats', dtype=dtype)

        # compose layers
        self.layer = collections.OrderedDict()

        # input layer, None indicates batch size not fixed at compile time
        self.layer['input'] = ll.InputLayer(
            (None, self.n_inputs), input_var=self.stats)

        # learn replacement values
        if self.impute_missing:
            self.layer['missing'] = \
                dl.ImputeMissingLayer(last(self.layer),
                                      n_inputs=(self.n_inputs,))
        else:
            self.layer['missing'] = \
                dl.ReplaceMissingLayer(last(self.layer),
                                       n_inputs=(self.n_inputs,))

        if self.n_bypass > 0 and (self.n_rnn > 0 or n_cnn > 0):
            last_layer = last(self.layer)
            bypass_slice = slice(self.n_inputs - self.n_bypass, self.n_inputs)
            direct_slice = slice(0, self.n_inputs - self.n_bypass)
            self.layer['bypass'] = ll.SliceLayer(last_layer, bypass_slice)
            self.layer['direct'] = ll.SliceLayer(last_layer, direct_slice)

        # reshape inputs prior to RNN or CNN step
        if self.n_rnn > 0 or n_cnn > 0:

            if len(n_filters) > 0 and len(self.input_shape) == 2:  # 1 channel
                rs = (-1, 1, *self.input_shape)
            else:
                if self.n_rnn > 0:
                    assert len(self.input_shape) == 2  # time, dim
                else:
                    assert len(self.input_shape) == 3  # channel, row, col
                rs = (-1, *self.input_shape)

            # last layer is 'missing' or 'direct'
            self.layer['reshape'] = ll.ReshapeLayer(last(self.layer), rs)

        # recurrent neural net, input: (batch, sequence_length, num_inputs)
        if self.n_rnn > 0:
            self.layer['rnn'] = ll.GRULayer(last(self.layer), n_rnn,
                                            only_return_final=True)

        # convolutional net, input: (batch, channels, rows, columns)
        if n_cnn > 0:
            for l in range(n_cnn):  # add layers
                if self.pool_sizes[l] == 1:
                    padding = (self.filter_sizes[l] - 1) // 2
                else:
                    padding = 0
                self.layer['conv_' + str(l + 1)] = ll.Conv2DLayer(
                    name='c' + str(l + 1),
                    incoming=last(self.layer),
                    num_filters=self.n_filters[l],
                    filter_size=self.filter_sizes[l],
                    stride=(1, 1),
                    pad=padding,
                    untie_biases=False,
                    W=lasagne.init.GlorotUniform(),
                    b=lasagne.init.Constant(0.),
                    nonlinearity=lnl.rectify,
                    flip_filters=True,
                    convolution=tt.nnet.conv2d)

                if self.pool_sizes[l] > 1:
                    self.layer['pool_' + str(l + 1)] = ll.MaxPool2DLayer(
                        name='p' + str(l + 1),
                        incoming=last(self.layer),
                        pool_size=self.pool_sizes[l],
                        stride=None,
                        ignore_border=True)

        # flatten
        self.layer['flatten'] = ll.FlattenLayer(
            incoming=last(self.layer),
            outdim=2)

        # incorporate bypass inputs
        if self.n_bypass > 0 and (self.n_rnn > 0 or n_cnn > 0):
            self.layer['bypass_merge'] = lasagne.layers.ConcatLayer(
                [self.layer['bypass'], last(self.layer)], axis=1)

        if self.density == 'mog':
            self.init_mdn(**density_opts)
        elif self.density == 'maf':
            self.init_maf(**density_opts)
        else:
            raise NotImplementedError

        self.compile_funs()  # theano functions
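
# Illustrative sketch (hypothetical helper, not part of the class): the constructor above
# requires that the reshaped input plus the bypass elements account for the full input,
# i.e. np.prod(input_shape) + n_bypass == n_inputs, with 1 to 3 reshape dimensions.
import numpy as np

def check_input_layout(n_inputs, input_shape=None, n_bypass=0):
    input_shape = (n_inputs,) if input_shape is None else tuple(input_shape)
    assert 1 <= len(input_shape) <= 3
    assert int(np.prod(input_shape)) + n_bypass == n_inputs
    return input_shape

# e.g. a 28 x 28 image plus 2 bypass summary stats:
# check_input_layout(n_inputs=28 * 28 + 2, input_shape=(28, 28), n_bypass=2)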
Example #5
def apt_loss_MoG_proposal(mdn, prior, n_proposal_components=None, svi=False):
    """Define loss function for training with a MoG proposal, allowing
    the proposal distribution to be different for each sample. The proposal
    means, precisions and weights are passed along with the stats and params.

    This function computes a symbolic expression but does not actually compile
    or calculate that expression.

    This loss does not include any regularization or prior terms, which must be
    added separately before compiling.

    Parameters
    ----------
    mdn : NeuralNet
        Mixture density network.
    prior : delfi distribution
        Prior distribution over parameters; must be Gaussian or uniform
    n_proposal_components : int or None
        Number of proposal mixture components; defaults to mdn.n_components
    svi : bool
        Whether to use SVI version of the mdn or not

    Returns
    -------
    loss : theano scalar
        Loss function
    trn_inputs : list
        Tensors to be provided to the loss function during training
    """
    assert mdn.density == 'mog'
    uniform_prior = isinstance(prior, dd.Uniform)
    if not uniform_prior and not isinstance(prior, dd.Gaussian):
        raise NotImplementedError  # prior must be Gaussian or uniform
    ncprop = mdn.n_components if n_proposal_components is None \
        else n_proposal_components  # default to component count of posterior

    nbatch = mdn.params.shape[0]

    # a mixture weights, ms means, P=U^T U precisions, QFs are tensorQF(P, m)
    a, ms, Us, ldetUs, Ps, ldetPs, Pms, QFs = \
        mdn.get_mog_tensors(return_extras=True, svi=svi)
    # convert mixture vars from lists of tensors to single tensors, of sizes:
    las = tt.log(a)
    Ps = tt.stack(Ps, axis=3).dimshuffle(0, 3, 1, 2)
    Pms = tt.stack(Pms, axis=2).dimshuffle(0, 2, 1)
    ldetPs = tt.stack(ldetPs, axis=1)
    QFs = tt.stack(QFs, axis=1)
    # las:    (batch, mdn.n_components)
    # Ps:     (batch, mdn.n_components, n_outputs, n_outputs)
    # Pms:    (batch, mdn.n_components, n_outputs)
    # ldetPs: (batch, mdn.n_components)
    # QFs:    (batch, mdn.n_components)

    # Define symbolic variables that hold, for each sample's MoG proposal:
    # precisions times means (batch, ncprop, n_outputs)
    # precisions (batch, ncprop, n_outputs, n_outputs)
    # log determinants of precisions (batch, ncprop)
    # log mixture weights (batch, ncprop)
    # quadratic forms QF = m^T P m (batch, ncprop)
    prop_Pms = tensorN(3, name='prop_Pms', dtype=dtype)
    prop_Ps = tensorN(4, name='prop_Ps', dtype=dtype)
    prop_ldetPs = tensorN(2, name='prop_ldetPs', dtype=dtype)
    prop_las = tensorN(2, name='prop_las', dtype=dtype)
    prop_QFs = tensorN(2, name='prop_QFs', dtype=dtype)

    # calculate corrections to precisions (P_0s) and precisions * means (Pm_0s)
    P_0s = prop_Ps
    Pm_0s = prop_Pms
    if not uniform_prior:  # Gaussian prior
        P_0s = P_0s - prior.P
        Pm_0s = Pm_0s - prior.Pm

    # To calculate the proposal posterior, we multiply all mixture component
    # pdfs from the true posterior by those from the proposal. The resulting
    # new mixture is ordered such that all product terms involving the first
    # component of the true posterior appear first. The shape of pp_Ps is
    # (batch, mdn.n_components, ncprop, n_outputs, n_outputs)
    pp_Ps = Ps.dimshuffle(0, 1, 'x', 2, 3) + P_0s.dimshuffle(0, 'x', 1, 2, 3)
    pp_Ss = invert_each(pp_Ps)  # covariances of proposal posterior components
    pp_ldetPs = det_each(pp_Ps, log=True)  # log determinants
    # precision times mean for each proposal posterior component:
    pp_Pms = Pms.dimshuffle(0, 1, 'x', 2) + Pm_0s.dimshuffle(0, 'x', 1, 2)
    # mean of proposal posterior components:
    pp_ms = (pp_Ss * pp_Pms.dimshuffle(0, 1, 2, 'x', 3)).sum(axis=4)
    # quadratic form defined by each pp_P evaluated at each pp_m
    pp_QFs = (pp_Pms * pp_ms).sum(axis=3)

    # normalization constants for integrals of Gaussian product-quotients
    # (for Gaussian proposals) or Gaussian products (for uniform priors)
    # Note we drop a "constant" (for each combination of sample, proposal
    # component posterior component) term of
    #
    # 0.5 * (tensorQF(prior.P, prior.m) - prior_ldetP)
    #
    # since we're going to normalize the pp mixture coefficients sum to 1
    pp_lZs = 0.5 * ((ldetPs - QFs).dimshuffle(0, 1, 'x') +
                    (prop_ldetPs - prop_QFs).dimshuffle(0, 'x', 1) -
                    (pp_ldetPs - pp_QFs))

    # calculate non-normalized log mixture coefficients of the proposal
    # posterior by adding log posterior weights a to normalization coefficients
    # Z. These do not yet sum to 1 in the linear domain
    pp_las_nonnormed = \
        las.dimshuffle(0, 1, 'x') + prop_las.dimshuffle(0, 'x', 1) + pp_lZs

    # reshape tensors describing proposal posterior components so that there's
    # only one dimension that ranges over components
    ncpp = ncprop * mdn.n_components  # number of proposal posterior components
    pp_las_nonnormed = pp_las_nonnormed.reshape((nbatch, ncpp))
    pp_ldetPs = pp_ldetPs.reshape((nbatch, ncpp))
    pp_ms = pp_ms.reshape((nbatch, ncpp, mdn.n_outputs))
    pp_Ps = pp_Ps.reshape((nbatch, ncpp, mdn.n_outputs, mdn.n_outputs))

    # normalize log mixture weights so they sum to 1 in the linear domain
    pp_las = pp_las_nonnormed - MyLogSumExp(pp_las_nonnormed, axis=1)

    mog_LL_inputs = \
        [(pp_ms[:, i, :], pp_Ps[:, i, :, :], pp_ldetPs[:, i])
         for i in range(ncpp)]  # list (over comps) of tuples (over vars)
    # 2 tensor inputs, lists (over comps) of tensors:
    mog_LL_inputs = [mdn.params, pp_las, *zip(*mog_LL_inputs)]

    loss = -tt.mean(mog_LL(*mog_LL_inputs))

    # collect extra input variables to be provided for each training data point
    trn_inputs = [
        mdn.params, mdn.stats, prop_Pms, prop_Ps, prop_ldetPs, prop_las,
        prop_QFs
    ]

    return loss, trn_inputs
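
# Illustrative sketch (plain numpy, single components, not part of the library): the
# proposal-posterior components above come from multiplying Gaussian densities, which
# simply adds precisions and precision-times-mean vectors. This is the uniform-prior
# case; for a Gaussian prior, its precision and Pm are first subtracted from the
# proposal terms.
import numpy as np

def combine_components(P_post, m_post, P_prop, m_prop):
    P_pp = P_post + P_prop                         # combined precision
    Pm_pp = P_post @ m_post + P_prop @ m_prop      # combined precision * mean
    m_pp = np.linalg.solve(P_pp, Pm_pp)            # combined mean
    return P_pp, m_pp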
Example #6
def apt_mdn_loss_atomic_proposal(mdn, svi=False, combined_loss=False):
    """Define loss function for training with a atomic proposal. Assumes a
    uniform proposal distribution over each sample parameter and an externally
    provided set of alternatives.

    mdn: NeuralNet
        Mixture density network.
    svi : bool
        Whether to use SVI version of the mdn or not
    """
    assert mdn.density == 'mog'

    # a is mixture weights, ms are means, U^T U are precisions
    a, ms, Us, ldetUs = mdn.get_mog_tensors(svi=svi)

    # define symbolic variables to hold the params that will be inferred
    # theta_all : n_batch x (n_atoms + 1) x n_outputs
    # lprs      : n_batch x (n_atoms + 1)
    theta_all = tensorN(3, name='params_nl',
                        dtype=dtype)  # true (row 1), atoms
    lprs = tensorN(2, name='lprs', dtype=dtype)  # log tilde_p / p

    # calculate Mahalanobis distances w.r.t. U'U for every (theta, x) pair
    # diffs : [ n_batch x (n_atoms+1) x n_outputs for each component ]
    # Ms    : [ n_batch x (n_atoms+1)             for each component ]
    # Ms[k][n,i] = (theta[i] - m[k][n])' U[k][n]' U[k][n] (theta[i] - m[k][n])
    dthetas = [theta_all - m.dimshuffle([0, 'x', 1])
               for m in ms]  # theta[i] - m[k][n]
    Ms = [
        tt.sum(tt.sum(dtheta.dimshuffle([0, 1, 'x', 2]) *
                      U.dimshuffle([0, 'x', 1, 2]),
                      axis=3)**2,
               axis=2) for dtheta, U in zip(dthetas, Us)
    ]

    # compute (unnormalized) log-densities, weighted by log prior ratios
    Ms = [-0.5 * M + lprs for M in Ms]

    # compute per-component log-densities and log-normalizers
    lprobs_comps = [M[:, 0] + ldetU for M, ldetU in zip(Ms, ldetUs)]
    lZ_comps = [
        MyLogSumExp(M, axis=1).squeeze() + ldetU
        for M, ldetU in zip(Ms, ldetUs)
    ]  # sum over all proposal thetas

    # compute overall log-densities and log-normalizers across components
    lq = MyLogSumExp(tt.stack(lprobs_comps, axis=1) + tt.log(a), axis=1)
    lZ = MyLogSumExp(tt.stack(lZ_comps, axis=1) + tt.log(a), axis=1)

    lprobs = lq.squeeze() - lZ.squeeze()

    # collect the extra input variables that have to be provided for each
    # training data point
    trn_inputs = [theta_all, mdn.stats, lprs]
    if combined_loss:  # add prior loss on prior samples
        l_ml = lq.squeeze()  # direct posterior evaluation
        is_prior_sample = tensorN(1, name='prop_mask', dtype=dtype)
        trn_inputs.append(is_prior_sample)
        loss = -tt.mean(lprobs + is_prior_sample * l_ml)
    else:
        loss = -tt.mean(lprobs)  # average over samples

    return loss, trn_inputs
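
# Illustrative sketch (plain numpy, single component, not part of the library): the Ms
# computed above are squared Mahalanobis distances through the factor U, i.e.
# M[n, i] = || U[n] (theta[n, i] - m[n]) ||^2.
import numpy as np

def mahalanobis_sq(theta_all, m, U):
    # theta_all: (n_batch, n_atoms + 1, n_outputs); m: (n_batch, n_outputs); U: (n_batch, n_outputs, n_outputs)
    dtheta = theta_all - m[:, None, :]
    proj = np.einsum('nrc,nac->nar', U, dtheta)    # apply U[n] to each difference vector
    return (proj ** 2).sum(axis=2)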
Example #7
def apt_loss_gaussian_proposal(mdn, prior, svi=False):
    """Define loss function for training with a Gaussian proposal, allowing
    the proposal distribution to be different for each sample. The proposal
    mean and precision are passed along with the stats and params.

    This function computes a symbolic expression but does not actually compile
    or calculate that expression.

    This loss does not include any regularization or prior terms, which must be
    added separately before compiling.

    Parameters
    ----------
    mdn : NeuralNet
        Mixture density network.
    prior : delfi distribution
        Prior distribution over parameters; must be Gaussian or uniform
    svi : bool
        Whether to use SVI version of the mdn or not

    Returns
    -------
    loss : theano scalar
        Loss function
    trn_inputs : list
        Tensors to be provided to the loss function during training
    """
    assert mdn.density == 'mog'
    uniform_prior = isinstance(prior, dd.Uniform)
    if not uniform_prior and not isinstance(prior, dd.Gaussian):
        raise NotImplementedError  # prior must be Gaussian or uniform

    # a mixture weights, ms means, P=U^T U precisions, QFs are tensorQF(P, m)
    a, ms, Us, ldetUs, Ps, ldetPs, Pms, QFs = \
        mdn.get_mog_tensors(return_extras=True, svi=svi)

    # define symbolic variables to hold for each sample's Gaussian proposal:
    # means (batch, n_outputs)
    # precisions (batch, n_outputs, n_outputs)
    prop_m = tensorN(2, name='prop_m', dtype=dtype)
    prop_P = tensorN(3, name='prop_P', dtype=dtype)

    # calculate corrections to precision (P_0) and precision * mean (Pm_0)
    P_0 = prop_P
    Pm_0 = tt.sum(prop_P * prop_m.dimshuffle(0, 'x', 1), axis=2)
    if not uniform_prior:  # Gaussian prior
        P_0 = P_0 - prior.P
        Pm_0 = Pm_0 - prior.Pm

    # precisions of proposal posterior components:
    pp_Ps = [P + P_0 for P in Ps]
    # covariances of proposal posterior components:
    pp_Ss = [invert_each(P) for P in pp_Ps]
    # log determinant of each proposal posterior component's precision:
    pp_ldetPs = [det_each(P, log=True) for P in pp_Ps]
    # precision times mean for each proposal posterior component:
    pp_Pms = [Pm + Pm_0 for Pm in Pms]
    # mean of proposal posterior components:
    pp_ms = [tt.batched_dot(S, Pm) for S, Pm in zip(pp_Ss, pp_Pms)]
    # quadratic form defined by each pp_P evaluated at each pp_m
    pp_QFs = [tt.sum(m * Pm, axis=1) for m, Pm in zip(pp_ms, pp_Pms)]

    # normalization constants for integrals of Gaussian product-quotients
    # (for Gaussian proposals) or Gaussian products (for uniform priors)
    # Note we drop a "constant" (for each sample, w.r.t trained params) term of
    #
    # 0.5 * (prop_ldetP - prior_ldetP +
    # tensorQF(prior.P, prior.m) - tensorQF(prop_P, prop_m))
    #
    # since we're going to normalize the pp mixture coefficients to sum to 1
    pp_lZs = [
        0.5 * (ldetP - pp_ldetP - QF + pp_QF)
        for ldetP, pp_ldetP, QF, pp_QF in zip(ldetPs, pp_ldetPs, QFs, pp_QFs)
    ]

    # calculate log mixture coefficients of proposal posterior in two steps:
    # 1) add log posterior weights a to normalization coefficients Z
    # 2) normalize to sum to 1 in the linear domain, but stay in the log domain
    pp_las = tt.stack(pp_lZs, axis=1) + tt.log(a)
    pp_las = pp_las - MyLogSumExp(pp_las, axis=1)

    loss = -tt.mean(mog_LL(mdn.params, pp_las, pp_ms, pp_Ps, pp_ldetPs))

    # collect extra input variables to be provided for each training data point
    trn_inputs = [mdn.params, mdn.stats, prop_m, prop_P]

    return loss, trn_inputs
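
# Illustrative sketch (plain numpy/scipy, not part of the library): the final two steps
# above form the proposal-posterior mixture weights in the log domain and normalize them
# with a log-sum-exp so they sum to 1 without underflow.
import numpy as np
from scipy.special import logsumexp

def normalize_log_weights(log_a, log_Z):
    # log_a: (batch, K) log posterior mixture weights; log_Z: (batch, K) log normalizers
    las = log_a + log_Z
    return las - logsumexp(las, axis=1, keepdims=True)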
Example #8
def apt_loss_gaussian_proposal(mdn,
                               prior,
                               svi=False,
                               add_prior_precision=True,
                               Ptol=1e-7):
    """Define loss function for training with a Gaussian proposal, allowing
    the proposal distribution to be different for each sample. The proposal
    mean and precision are passed along with the stats and params.

    This function computes a symbolic expression but does not actually compile
    or calculate that expression.

    This loss does not include any regularization or prior terms, which must be
    added separately before compiling.

    Parameters
    ----------
    mdn : NeuralNet
        Mixture density network.
    prior : delfi distribution
        Prior distribution over parameters; must be Gaussian or uniform
    svi : bool
        Whether to use SVI version of the mdn or not
    add_prior_precision : bool
        If True, the prior precision is not subtracted from the proposal
        precision (relevant only for Gaussian priors)
    Ptol : float
        Diagonal jitter added to the rescaled precisions for numerical
        conditioning

    Returns
    -------
    loss : theano scalar
        Loss function
    trn_inputs : list
        Tensors to be provided to the loss function during training
    """
    assert mdn.density == 'mog'
    uniform_prior = isinstance(prior, dd.Uniform)
    if not uniform_prior and not isinstance(prior, dd.Gaussian):
        raise NotImplementedError  # prior must be Gaussian or uniform

    # a mixture weights, ms means, P=U^T U precisions, QFs are tensorQF(P, m)
    a, ms, Us, ldetUs, Ps, ldetPs, Pms, QFs = \
        mdn.get_mog_tensors(return_extras=True, svi=svi)

    # define symbolic variables to hold for each sample's Gaussian proposal:
    # means (batch, n_outputs)
    # precisions (batch, n_outputs, n_outputs)
    prop_m = tensorN(2, name='prop_m', dtype=dtype)
    prop_P = tensorN(3, name='prop_P', dtype=dtype)

    # calculate corrections to precision (P_0) and precision * mean (Pm_0)
    P_0 = prop_P
    Pm_0 = tt.sum(prop_P * prop_m.dimshuffle(0, 'x', 1), axis=2)
    if not uniform_prior and not add_prior_precision:  # Gaussian prior
        P_0 = P_0 - prior.P
        Pm_0 = Pm_0 - prior.Pm

    # precisions of proposal posterior components (before numerical conditioning step)
    pp_Ps = [P + P_0 for P in Ps]

    # get square roots of the diagonal entries of the proposal posterior precision components, which equal the L2
    # norms of the columns of the Cholesky factor of the same matrix. We'll use these to improve the numerical
    # conditioning of pp_Ps
    ds = [
        tt.sqrt(tt.sum(pp_P * np.eye(mdn.n_outputs), axis=2)) for pp_P in pp_Ps
    ]
    # normalize the estimate of each true posterior component according to the corresponding elements of d:
    # first normalize the Cholesky factor of the true posterior component estimate...
    Us_normed = [U / d.dimshuffle(0, 'x', 1) for U, d in zip(Us, ds)]
    # ... then normalize the proposal. The resulting list is the same proposal, normalized differently for each
    # component of the true posterior
    P_0s_normed = [
        P_0 / (d.dimshuffle(0, 'x', 1) * d.dimshuffle(0, 1, 'x')) for d in ds
    ]
    pp_Ps_normed = [
        tt.batched_dot(U_normed.dimshuffle(0, 2, 1), U_normed) + P_0_normed +
        np.eye(mdn.n_outputs) * Ptol
        for U_normed, P_0_normed in zip(Us_normed, P_0s_normed)
    ]
    # lower Cholesky factors for normalized precisions of proposal posterior components
    pp_Ls_normed = [cholesky_each(pp_P_normed) for pp_P_normed in pp_Ps_normed]
    # log determinants of lower Cholesky factors for normalized precisions of proposal posterior components
    pp_ldetLs_normed = [
        tt.sum(tt.log(tt.sum(pp_L_normed * np.eye(mdn.n_outputs), axis=2)),
               axis=1) for pp_L_normed in pp_Ls_normed
    ]
    # precisions of proposal posterior components (now well-conditioned)
    pp_Ps = [
        d.dimshuffle(0, 1, 'x') * pp_P_normed * d.dimshuffle(0, 'x', 1)
        for pp_P_normed, d in zip(pp_Ps_normed, ds)
    ]
    # log determinants of proposal posterior precisions
    pp_ldetPs = [
        2.0 * (tt.sum(tt.log(d), axis=1) + pp_ldetL_normed)
        for d, pp_ldetL_normed in zip(ds, pp_ldetLs_normed)
    ]

    # covariances of proposal posterior components:
    pp_Ss = [invert_each(P) for P in pp_Ps]
    # precision times mean for each proposal posterior component:
    pp_Pms = [Pm + Pm_0 for Pm in Pms]
    # mean of proposal posterior components:
    pp_ms = [tt.batched_dot(S, Pm) for S, Pm in zip(pp_Ss, pp_Pms)]
    # quadratic form defined by each pp_P evaluated at each pp_m
    pp_QFs = [tt.sum(m * Pm, axis=1) for m, Pm in zip(pp_ms, pp_Pms)]

    # normalization constants for integrals of Gaussian product-quotients
    # (for Gaussian proposals) or Gaussian products (for uniform priors)
    # Note we drop a "constant" (for each sample, w.r.t trained params) term of
    #
    # 0.5 * (prop_ldetP - prior_ldetP + tensorQF(prior.P, prior.m) - tensorQF(prop_P, prop_m))
    #
    # since we're going to normalize the pp mixture coefficients to sum to 1
    pp_lZs = [
        0.5 * (ldetP - pp_ldetP - QF + pp_QF)
        for ldetP, pp_ldetP, QF, pp_QF in zip(ldetPs, pp_ldetPs, QFs, pp_QFs)
    ]

    # calculate log mixture coefficients of proposal posterior in two steps:
    # 1) add log posterior weights a to normalization coefficients Z
    # 2) normalize to sum to 1 in the linear domain, but stay in the log domain
    pp_las = tt.stack(pp_lZs, axis=1) + tt.log(a)
    pp_las = pp_las - MyLogSumExp(pp_las, axis=1)

    loss = -tt.mean(mog_LL(mdn.params, pp_las, pp_ms, pp_Ps, pp_ldetPs))

    # collect extra input variables to be provided for each training data point
    trn_inputs = [mdn.params, mdn.stats, prop_m, prop_P]

    return loss, trn_inputs
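
# Illustrative sketch (single matrix, plain numpy, not part of the library): the
# conditioning step above rescales the precision by the square roots of its diagonal,
# adds a small diagonal jitter, Cholesky-factorizes the rescaled matrix, and recovers
# the log-determinant of the original precision from the two pieces.
import numpy as np

def conditioned_cholesky_logdet(P, Ptol=1e-7):
    d = np.sqrt(np.diag(P))                                  # sqrt of diagonal entries
    P_normed = P / np.outer(d, d) + Ptol * np.eye(len(d))    # ~unit diagonal + jitter
    L_normed = np.linalg.cholesky(P_normed)                  # lower Cholesky factor
    logdet = 2.0 * (np.log(d).sum() + np.log(np.diag(L_normed)).sum())
    P_conditioned = np.outer(d, d) * P_normed                # rescaled back, well-conditioned
    return P_conditioned, logdet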