def genprior(varname, prior_desc):
    """Generate code that samples the coefficient vector `varname`.

    `prior_desc` is a list of `(Family, length)` pairs. Each pair yields
    one sample statement for a segment named `{varname}_{i}`. The
    segments are then bound to `varname` as follows:

    * no segments: an empty tensor;
    * one segment: that segment directly;
    * several segments: their concatenation (`torch.cat`).

    Returns the list of generated code lines.
    """
    assert type(varname) == str
    assert type(prior_desc) == list
    assert all(
        type(p) == tuple and type(p[0]) == Family and type(p[1]) == int
        for p in prior_desc)

    segment_names = ['{}_{}'.format(varname, idx)
                     for idx in range(len(prior_desc))]
    # One sample statement per segment of the coefficient vector.
    lines = [
        sample(name, gendist(family, args(family), [size], False))
        for name, (family, size) in zip(segment_names, prior_desc)
    ]

    if not segment_names:
        lines.append('{} = torch.tensor([])'.format(varname))
    elif len(segment_names) == 1:
        # A lone segment already is the whole vector.
        lines.append('{} = {}'.format(varname, segment_names[0]))
    else:
        # Stitch the segments together into the final vector.
        lines.append('{} = torch.cat([{}])'.format(
            varname, ', '.join(segment_names)))
    return lines
def genprior(varname, prior_desc):
    """Generate code that samples the coefficient vector `varname`.

    `prior_desc` is a list of `(Family, length)` pairs. One sample
    statement is generated per pair, binding a segment named
    `{varname}_{i}`. The segments are then combined into `varname`:

    * no segments: `varname` is bound to an empty tensor;
    * one segment: `varname` is bound to that segment directly;
    * several segments: `varname` is their `torch.cat` concatenation.

    Returns the list of generated code lines.
    """
    assert type(varname) == str
    assert type(prior_desc) == list
    assert all(
        type(p) == tuple and type(p[0]) == Family and type(p[1]) == int
        for p in prior_desc)
    code = []
    # Sample each segment of a coefficient vector.
    for i, (prior, length) in enumerate(prior_desc):
        code.append(
            sample('{}_{}'.format(varname, i),
                   gendist(prior, args(prior), [length], False)))
    num_segments = len(prior_desc)
    if num_segments == 0:
        code.append('{} = torch.tensor([])'.format(varname))
    elif num_segments == 1:
        # A single segment already is the whole vector. Bind it directly
        # instead of generating a `torch.cat` of a one-element list.
        code.append('{} = {}_0'.format(varname, varname))
    else:
        # Concatenate the segments to produce the final vector.
        code.append('{} = torch.cat([{}])'.format(
            varname,
            ', '.join('{}_{}'.format(varname, i)
                      for i in range(num_segments))))
    return code
def genmodel(model):
    """Generate the source code of the Pyro `model` function for `model`.

    The generated function takes the design matrix `X`, one `Z_i` and
    one `J_i` per group, and the keyword arguments `y_obs`, `dfN`,
    `subsample` and `mode`. `mode` selects how much of the model runs:

    * "prior_only":   sample the priors only; `mu` is `None`.
    * "prior_and_mu": additionally compute the linear predictor `mu`.
    * "full":         additionally sample the response `y` (observed as
                      `y_obs`) inside an "obs" plate of size `dfN`,
                      subsampled by `subsample`.

    The generated function returns a dict holding `mu`, `b`, and every
    group's `sd_i` and `r_i`.

    Returns the lines produced by `method('model', params, body)` joined
    with newlines.
    """
    assert type(model) == ModelDesc
    n_groups = len(model.groups)

    # Argument checks at the top of the generated function. `subsample`
    # and `dfN` must be passed together or not at all. When `dfN` is
    # omitted it defaults to the number of rows of `X`.
    body = [
        'assert mode == "full" or mode == "prior_and_mu" or mode == "prior_only"',
        'assert (subsample is None) == (dfN is None)',
        'assert type(X) == torch.Tensor',
        'N = X.shape[0]',
        'if dfN is None:',
        indent('dfN = N'),
        'else:',
        indent('assert len(subsample) == N'),
        # M is the number of columns in the design matrix.
        'M = {}'.format(len(model.population.coefs)),
        'assert X.shape == (N, M)',
        '',
    ]

    # Population level: prior over b, the population-level coefficients.
    body += genprior('b', contig(model.population.priors))
    body.append('assert b.shape == (M,)')

    # Group level: each group contributes sampling code to the body, plus
    # the lines that add its contribution to mu (emitted further down).
    mu_lines = []
    for grp_idx, grp in enumerate(model.groups):
        grp_body, grp_mu = gengroup(grp_idx, grp)
        body += grp_body
        mu_lines += grp_mu

    # The linear predictor mu is skipped entirely in "prior_only" mode.
    body += [
        '',
        'if mode == "prior_only":',
        indent('mu = None'),
        'else:',
        indent('mu = torch.mv(X, b)'),
    ]
    body += [indent(line) for line in mu_lines]
    body.append('')

    # Response: first the priors over the response distribution
    # parameters that aren't predicted from the data, then (only in
    # "full" mode) the observations themselves.
    body += [
        sample(param.name, gendist(prior, args(prior), [1], False))
        for param, prior in zip(model.response.nonlocparams,
                                model.response.priors)
    ]
    body += [
        'if mode == "full":',
        indent('with pyro.plate("obs", dfN, subsample=subsample):'),
        indent(indent(sample('y', gen_response_dist(model), 'y_obs'))),
    ]

    # Quantities not bound by sample statements (such as `b`) are
    # returned so that they can be recovered from the execution trace.
    group_ids = range(n_groups)
    returned = (['mu', 'b']
                + [f'sd_{g}' for g in group_ids]
                + [f'r_{g}' for g in group_ids])
    entries = ', '.join(f"'{name}': {name}" for name in returned)
    body.append('')
    body.append('return {' + entries + '}')

    params = ['X']
    params += [f'Z_{g}' for g in group_ids]
    params += [f'J_{g}' for g in group_ids]
    params += ['y_obs=None', 'dfN=None', 'subsample=None', 'mode="full"']
    return '\n'.join(method('model', params, body))
def gengroup(i, group):
    """Generate the code for the group-level coefficients of one group.

    `i` is a unique integer assigned to the group. It suffixes every
    generated name (`M_i`, `N_i`, `Z_i`, `J_i`, `sd_i`, `z_i`, `L_i`,
    `r_i`); the names follow those used by brms.

    Returns a pair `(code, mu_code)` of lists of generated lines:

    * `code` checks the group's inputs and samples its coefficients
      `r_i`, asserted to have shape `(N_i, M_i)` (levels x coefficients).
    * `mu_code` adds the group's contribution to `mu`.
    """
    assert type(i) == int
    assert type(group) == Group

    factor = ':'.join(group.columns)
    header = comment(f'Group {i}: factor={factor}')
    num_coefs = len(group.coefs)    # coefficients per level
    num_levels = len(group.levels)  # levels of the grouping factor

    code = [
        '',
        header,
        f'M_{i} = {num_coefs} # Number of coeffs',
        f'N_{i} = {num_levels} # Number of levels',
        f'assert type(Z_{i}) == torch.Tensor',
        f'assert Z_{i}.shape == (N, M_{i}) # N x {num_coefs}',
        f'assert type(J_{i}) == torch.Tensor',
        f'assert J_{i}.shape == (N,)',
    ]

    # Prior over the coefficient scales.
    code += genprior(f'sd_{i}', contig(group.sd_priors))
    code.append(f'assert sd_{i}.shape == (M_{i},) # {num_coefs}')

    # Prior over an M_i x N_i matrix of unscaled, uncorrelated
    # coefficients, similar to the Stan code brms generates. An
    # alternative would be to pass `torch.mm(torch.diag(sd_i), L_i)` as
    # the `scale_tril` of a `MultivariateNormal`. Whether the two
    # approaches differ significantly is an open question.
    code.append(
        sample(f'z_{i}',
               gendist(Normal, [0., 1.], [num_coefs, num_levels],
                       batch=False)))
    code.append(
        f'assert z_{i}.shape == (M_{i}, N_{i}) # {num_coefs} x {num_levels}')

    if group.corr_prior:
        # Model correlations between the coefficients. The way the prior
        # tree is built guarantees there is more than one coefficient.
        assert num_coefs > 1
        corr_prior = group.corr_prior
        assert len(args(corr_prior)) == 1
        code.append(
            sample(f'L_{i}',
                   lkj_corr_cholesky(num_coefs, shape=args(corr_prior)[0])))
        code.append(
            f'assert L_{i}.shape == (M_{i}, M_{i}) # {num_coefs} x {num_coefs}')
        # Scaled, correlated coefficients. When L_i is the identity (no
        # correlation) this reduces to the computation in the else-branch.
        code.append(
            f'r_{i} = torch.mm(torch.mm(torch.diag(sd_{i}), L_{i}), z_{i}).transpose(0, 1)')
    else:
        # Scaled (uncorrelated) coefficients.
        code.append(f'r_{i} = (z_{i} * sd_{i}.unsqueeze(1)).transpose(0, 1)')
    code.append(
        f'assert r_{i}.shape == (N_{i}, M_{i}) # {num_levels} x {num_coefs}')

    # XXX: `r_i[J_i]` allocates a large intermediate tensor. Iterating
    # over the N_i levels and using scatter_add to add each level's
    # contribution to its rows of mu might avoid that.
    mu_code = [header, f'mu = mu + torch.sum(Z_{i} * r_{i}[J_{i}], 1)']
    return code, mu_code
def gengroup(i, group):
    """Generate the code for the group-level coefficients of one group.

    `i` is a unique integer assigned to the group. It suffixes every
    generated name (`M_i`, `N_i`, `Z_i`, `J_i`, `sd_i`, `z_i`, `L_i`,
    `r_i`); the names follow those used by brms.

    Returns a pair `(code, mu_code)` of lists of generated lines:

    * `code` checks the group's inputs and samples its coefficients
      `r_i`, asserted to have shape `(N_i, M_i)` (levels x coefficients).
    * `mu_code` adds the group's contribution to `mu`, one coefficient
      column at a time.
    """
    assert type(i) == int
    assert type(group) == Group

    factor = ':'.join(group.columns)
    header = comment(f'Group {i}: factor={factor}')
    num_coefs = len(group.coefs)    # coefficients per level
    num_levels = len(group.levels)  # levels of the grouping factor

    code = [
        '',
        header,
        f'M_{i} = {num_coefs} # Number of coeffs',
        f'N_{i} = {num_levels} # Number of levels',
        f'assert type(Z_{i}) == torch.Tensor',
        f'assert Z_{i}.shape == (N, M_{i}) # N x {num_coefs}',
        f'assert type(J_{i}) == torch.Tensor',
        f'assert J_{i}.shape == (N,)',
    ]

    # Prior over the coefficient scales.
    code += genprior(f'sd_{i}', contig(group.sd_priors))
    code.append(f'assert sd_{i}.shape == (M_{i},) # {num_coefs}')

    # Prior over an M_i x N_i matrix of unscaled, uncorrelated
    # coefficients, similar to the Stan code brms generates. An
    # alternative would be to pass `torch.mm(torch.diag(sd_i), L_i)` as
    # the `scale_tril` of a `MultivariateNormal`. Whether the two
    # approaches differ significantly is an open question.
    code.append(
        sample(f'z_{i}',
               gendist(Normal, [0., 1.], [num_coefs, num_levels],
                       batch=False)))
    code.append(
        f'assert z_{i}.shape == (M_{i}, N_{i}) # {num_coefs} x {num_levels}')

    if group.corr_prior:
        # Model correlations between the coefficients. The way the prior
        # tree is built guarantees there is more than one coefficient.
        assert num_coefs > 1
        corr_prior = group.corr_prior
        assert len(args(corr_prior)) == 1
        code.append(
            sample(f'L_{i}',
                   lkj_corr_cholesky(num_coefs, shape=args(corr_prior)[0])))
        code.append(
            f'assert L_{i}.shape == (M_{i}, M_{i}) # {num_coefs} x {num_coefs}')
        # Scaled, correlated coefficients. When L_i is the identity (no
        # correlation) this reduces to the computation in the else-branch.
        code.append(
            f'r_{i} = torch.mm(torch.mm(torch.diag(sd_{i}), L_{i}), z_{i}).transpose(0, 1)')
    else:
        # Scaled (uncorrelated) coefficients.
        code.append(f'r_{i} = (z_{i} * sd_{i}.unsqueeze(1)).transpose(0, 1)')
    code.append(
        f'assert r_{i}.shape == (N_{i}, M_{i}) # {num_levels} x {num_coefs}')

    # The mu code mirrors the structure of brms-generated code, to ease
    # comparison, though it may not be the fastest option in PyTorch.
    # The vectorised alternative
    #     mu = mu + torch.sum(Z_i * r_i[J_i], 1)
    # avoids the per-coefficient loop but allocates the large
    # intermediate tensor `r_i[J_i]`.
    # Generated names use 1-based coefficient numbers; column indices
    # are 0-based.
    cols = range(1, num_coefs + 1)
    mu_code = [header]
    mu_code += [f'r_{i}_{k} = r_{i}[:, {k - 1}]' for k in cols]
    mu_code += [f'Z_{i}_{k} = Z_{i}[:, {k - 1}]' for k in cols]
    mu_code += [f'mu = mu + r_{i}_{k}[J_{i}] * Z_{i}_{k}' for k in cols]
    return code, mu_code
def genmodel(model):
    """Generate the source code of the `model` function for `model`.

    The generated function takes the design matrix `X` (asserted to be an
    `onp.ndarray`), one `Z_i` and one `J_i` per group, and the keyword
    arguments `y_obs` and `mode`. `mode` selects how much of the model
    runs:

    * "prior_only":   sample the priors only; `mu` is `None`.
    * "prior_and_mu": additionally compute the linear predictor `mu`.
    * "full":         additionally sample the response `y` (observed as
                      `y_obs`).

    The generated function returns a dict holding `mu`, `b`, and every
    group's `sd_i` and `r_i`.

    Returns the lines produced by `method('model', params, body)` joined
    with newlines.
    """
    assert type(model) == ModelDesc
    n_groups = len(model.groups)

    body = [
        'assert mode == "full" or mode == "prior_and_mu" or mode == "prior_only"',
        'assert type(X) == onp.ndarray',
        'N = X.shape[0]',
        # M is the number of columns in the design matrix.
        'M = {}'.format(len(model.population.coefs)),
        'assert X.shape == (N, M)',
    ]

    # Population level: prior over b, the population-level coefficients.
    body += genprior('b', contig(model.population.priors))
    body.append('assert b.shape == (M,)')

    # Group level: each group contributes sampling code to the body, plus
    # the lines that add its contribution to mu (emitted further down).
    mu_lines = []
    for grp_idx, grp in enumerate(model.groups):
        grp_body, grp_mu = gengroup(grp_idx, grp)
        body += grp_body
        mu_lines += grp_mu

    # The linear predictor mu is skipped entirely in "prior_only" mode.
    body += [
        '',
        'if mode == "prior_only":',
        indent('mu = None'),
        'else:',
        indent('mu = np.matmul(X, b)'),
    ]
    body += [indent(line) for line in mu_lines]
    body.append('')

    # Response: the priors over the response distribution parameters
    # that aren't predicted from the data.
    body += [
        sample(param.name, gendist(prior, args(prior), [1]))
        for param, prior in zip(model.response.nonlocparams,
                                model.response.priors)
    ]

    # TODO: Guarding the response with `mode == "full"` lets `location`
    # (part of the backend interface) run the model forward without
    # threading an RNG through it. The check is unnecessary during
    # inference, so this approach may be revisited.
    body += [
        'if mode == "full":',
        indent(sample('y', gen_response_dist(model), 'y_obs')),
    ]

    # Quantities not bound by sample statements (such as `b`) are
    # returned so that they can be recovered from the execution trace.
    group_ids = range(n_groups)
    returned = (['mu', 'b']
                + [f'sd_{g}' for g in group_ids]
                + [f'r_{g}' for g in group_ids])
    entries = ', '.join(f"'{name}': {name}" for name in returned)
    body.append('return {' + entries + '}')

    params = ['X']
    params += [f'Z_{g}' for g in group_ids]
    params += [f'J_{g}' for g in group_ids]
    params += ['y_obs=None', 'mode="full"']
    return '\n'.join(method('model', params, body))