def create_branch_model(id_, tree_id, taxa_count, arg):
    """Build the JSON specification of a clock (branch rate) model.

    :param id_: identifier of the branch model object
    :param tree_id: identifier of the associated tree model
    :param taxa_count: number of taxa in the tree
    :param arg: parsed command-line arguments
    :return: dictionary describing a strict, horseshoe, or ucln clock model
    """
    # initial substitution rate: fixed user value > user init > default
    if arg.rate is not None:
        initial_rate = [arg.rate]
    else:
        initial_rate = [arg.rate_init] if 'rate_init' in arg else [0.001]

    rate_parameter = Parameter.json_factory(f'{id_}.rate', **{'tensor': initial_rate})
    rate_parameter['lower'] = 0.0
    if arg.rate is not None:
        # a user-specified rate is fixed by collapsing its bounds
        rate_parameter['lower'] = rate_parameter['upper'] = arg.rate

    if arg.clock == 'strict':
        return {
            'id': id_,
            'type': 'StrictClockModel',
            'tree_model': tree_id,
            'rate': rate_parameter,
        }

    branch_count = 2 * taxa_count - 2

    if arg.clock == 'horseshoe':
        unscaled = Parameter.json_factory(
            f'{id_}.rates.unscaled', **{'tensor': 1.0, 'full': [branch_count]}
        )
        unscaled['lower'] = 0.0
        # per-branch rates rescaled so the tree-wide mean matches rate_parameter
        rescaled = {
            'id': f'{id_}.rates',
            'type': 'TransformedParameter',
            'transform': 'RescaledRateTransform',
            'x': unscaled,
            'parameters': {
                'tree_model': tree_id,
                'rate': rate_parameter,
            },
        }
        return {
            'id': f'{id_}.simple',
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': rescaled,
        }

    if arg.clock == 'ucln':
        per_branch = Parameter.json_factory(
            f'{id_}.rates', **{'tensor': 0.001, 'full': [branch_count]}
        )
        per_branch['lower'] = 0.0
        return {
            'id': id_,
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': per_branch,
        }
def create_birth_death(birth_death_id, tree_id, arg):
    """Build the JSON specification of a BDSK (birth-death skyline) model.

    :param birth_death_id: identifier of the birth-death model object
    :param tree_id: identifier of the associated tree model
    :param arg: parsed command-line arguments (uses arg.grid)
    :return: dictionary describing a BDSKModel
    """

    def _piecewise(name, value):
        # piecewise-constant parameter: one value per grid interval
        return Parameter.json_factory(
            f'{birth_death_id}.{name}', **{'tensor': value, 'full': [arg.grid]}
        )

    R = _piecewise('R', 3.0)
    R['lower'] = 0.0

    delta = _piecewise('delta', 3.0)
    delta['lower'] = 0.0

    s = _piecewise('s', 0.0)
    # fixed at zero: both bounds collapsed onto the value
    s['lower'] = 0.0
    s['upper'] = 0.0

    rho = Parameter.json_factory(f'{birth_death_id}.rho', **{'tensor': [1.0e-6]})
    rho['lower'] = 0.0
    rho['upper'] = 1.0

    # origin = root height + a positive unshifted offset parameter
    origin = {
        'id': f'{birth_death_id}.origin',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.AffineTransform',
        'x': {
            'id': f'{birth_death_id}.origin.unshifted',
            'type': 'Parameter',
            'tensor': [1.0],
            'lower': 0.0,
        },
        'parameters': {
            'loc': f'{tree_id}.root_height',
            'scale': 1.0,
        },
    }

    return {
        'id': birth_death_id,
        'type': 'BDSKModel',
        'tree_model': tree_id,
        'R': R,
        'delta': delta,
        's': s,
        'rho': rho,
        'origin': origin,
    }
def create_clock_horseshoe_prior(branch_model_id, tree_id):
    """Build the list of priors for a horseshoe relaxed clock model.

    :param branch_model_id: identifier of the branch model object
    :param tree_id: identifier of the associated tree model
    :return: list of prior JSON dictionaries
    """
    # log-differences of adjacent branch rates receive the horseshoe
    # (scale mixture of normals) prior
    log_diff = {
        'id': f'{branch_model_id}.rates.logdiff',
        'type': 'TransformedParameter',
        'transform': 'LogDifferenceRateTransform',
        'x': f'{branch_model_id}.rates.unscaled',
        'parameters': {'tree_model': tree_id},
    }
    global_scale = Parameter.json_factory(
        f'{branch_model_id}.global.scale', **{'tensor': [1.0]}
    )
    local_scale = Parameter.json_factory(
        f'{branch_model_id}.local.scales',
        **{'tensor': 1.0, 'full_like': f'{branch_model_id}.rates.unscaled'},
    )
    global_scale['lower'] = 0.0
    local_scale['lower'] = 0.0

    priors = [
        ScaleMixtureNormal.json_factory(
            f'{branch_model_id}.scale.mixture.prior',
            log_diff,
            0.0,
            global_scale,
            local_scale,
        ),
        # CTMC reference prior on the mean clock rate
        CTMCScale.json_factory(
            f'{branch_model_id}.rate.prior', f'{branch_model_id}.rate', 'tree'
        ),
    ]
    # Cauchy(0, 1) priors on the (positively constrained) global/local scales
    priors.extend(
        Distribution.json_factory(
            f'{branch_model_id}.{p}.prior',
            'torch.distributions.Cauchy',
            f'{branch_model_id}.{p}',
            {'loc': 0.0, 'scale': 1.0},
        )
        for p in ('global.scale', 'local.scales')
    )
    return priors
def create_gamma_distribution(var_id, x_unres, json_object, concentration, rate):
    """Build a Gamma distribution JSON object with exp-transformed parameters.

    Both concentration and rate are stored unconstrained and mapped through an
    exponential transform so their realized values stay positive.

    :param var_id: identifier prefix of the variational object
    :param x_unres: the (unconstrained) variable the distribution is over
    :param json_object: JSON object whose 'id' names the target parameter
    :param concentration: initial value of the unconstrained concentration
    :param rate: initial value of the unconstrained rate
    :return: tuple (distribution, concentration parameter, rate parameter)
    """

    def _exp_transformed(name, value):
        # positive parameter realized as exp(unconstrained value)
        prefix = var_id + '.' + json_object['id'] + '.' + name
        return {
            'id': prefix,
            'type': 'TransformedParameter',
            'transform': 'torch.distributions.ExpTransform',
            'x': Parameter.json_factory(
                prefix + '.unres',
                **{'full_like': json_object['id'], 'tensor': value},
            ),
        }

    concentration_param = _exp_transformed('concentration', concentration)
    rate_param = _exp_transformed('rate', rate)
    distr = Distribution.json_factory(
        var_id + '.' + json_object['id'],
        'torch.distributions.Gamma',
        x_unres,
        {'concentration': concentration_param, 'rate': rate_param},
    )
    return distr, concentration_param, rate_param
def create_normal_distribution(var_id, x_unres, json_object, loc, scale):
    """Build a Normal distribution JSON object.

    loc is a free parameter while scale is stored unconstrained and mapped
    through an exponential transform so it stays positive. When loc/scale are
    given as lists they fully specify the tensor, so 'full_like' is dropped.

    :param var_id: identifier prefix of the variational object
    :param x_unres: the (unconstrained) variable the distribution is over
    :param json_object: JSON object whose 'id' names the target parameter
    :param loc: initial loc (scalar broadcast via full_like, or explicit list)
    :param scale: initial unconstrained scale (scalar or explicit list)
    :return: tuple (distribution, loc parameter, scale parameter)
    """
    base = var_id + '.' + json_object['id']

    loc_param = Parameter.json_factory(
        base + '.loc', **{'full_like': x_unres, 'tensor': loc}
    )
    if isinstance(loc, list):
        del loc_param['full_like']

    scale_param = {
        'id': base + '.scale',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.ExpTransform',
        'x': Parameter.json_factory(
            base + '.scale.unres', **{'full_like': x_unres, 'tensor': scale}
        ),
    }
    if isinstance(scale, list):
        del scale_param['x']['full_like']

    distr = Distribution.json_factory(
        base,
        'torch.distributions.Normal',
        x_unres,
        {'loc': loc_param, 'scale': scale_param},
    )
    return distr, loc_param, scale_param
def create_site_model_srd06_mus(id_):
    """Build the relative-rate (mu) parameter shared by the SRD06 partitions.

    A 2-element simplex is combined through a convex combination with weights
    2/3 and 1/3 (the proportions of sites in the two partitions), constraining
    the weighted average rate.

    :param id_: identifier of the transformed mu parameter
    :return: dictionary describing a ConvexCombinationTransform parameter
    """
    simplex = Parameter.json_factory('srd06.mu', **{'tensor': [0.5, 0.5]})
    simplex['simplex'] = True
    return {
        'id': id_,
        'type': 'TransformedParameter',
        'transform': 'ConvexCombinationTransform',
        'x': simplex,
        'parameters': {'weights': [2 / 3, 1 / 3]},
    }
def create_ucln_prior(branch_model_id):
    """Build the list of priors for an uncorrelated lognormal relaxed clock.

    :param branch_model_id: identifier of the branch model object
    :return: list of prior JSON dictionaries
    """
    mean = Parameter.json_factory(
        f'{branch_model_id}.rates.prior.mean', **{'tensor': [0.001]}
    )
    mean['lower'] = 0.0
    scale = Parameter.json_factory(
        f'{branch_model_id}.rates.prior.scale', **{'tensor': [1.0]}
    )
    scale['lower'] = 0.0

    return [
        # branch rates ~ LogNormal(mean, scale)
        Distribution.json_factory(
            f'{branch_model_id}.rates.prior',
            'LogNormal',
            f'{branch_model_id}.rates',
            {
                'mean': mean,
                'scale': scale,
            },
        ),
        # CTMC reference prior on the lognormal mean
        CTMCScale.json_factory(
            f'{branch_model_id}.mean.prior',
            f'{branch_model_id}.rates.prior.mean',
            'tree',
        ),
        # Gamma prior on the lognormal scale
        Distribution.json_factory(
            f'{branch_model_id}.rates.scale.prior',
            'torch.distributions.Gamma',
            f'{branch_model_id}.rates.prior.scale',
            {
                'concentration': 0.5396,
                'rate': 2.6184,
            },
        ),
    ]
def create_site_model(id_, arg, w=None):
    """Build the JSON specification of a site (rate heterogeneity) model.

    :param id_: identifier of the site model object
    :param arg: parsed command-line arguments (uses arg.categories, arg.model)
    :param w: optional mu parameter, attached when the SRD06 model is used
    :return: dictionary describing a ConstantSiteModel or WeibullSiteModel
    """
    if arg.categories == 1:
        # a single category means no rate heterogeneity across sites
        site_model = {'id': id_, 'type': 'ConstantSiteModel'}
    else:
        weibull_shape = Parameter.json_factory(f'{id_}.shape', **{'tensor': [0.1]})
        weibull_shape['lower'] = 0.0
        site_model = {
            'id': id_,
            'type': 'WeibullSiteModel',
            'categories': arg.categories,
            'shape': weibull_shape,
        }
    # SRD06 partitions share a relative-rate (mu) parameter
    if arg.model == 'SRD06':
        site_model['mu'] = w
    return site_model
def create_time_tree_prior(taxa, arg):
    """Build the tree prior (coalescent or birth-death) and its hyperpriors.

    :param taxa: dictionary with a 'taxa' list (its length sizes the skyride)
    :param arg: parsed command-line arguments
    :return: list containing the tree prior followed by hyperprior objects
    """
    prior = []
    joint_list = []
    if arg.coalescent is not None:
        params = {}
        coalescent_id = 'coalescent'
        if arg.coalescent in ('constant', 'exponential'):
            theta = Parameter.json_factory(
                f'{coalescent_id}.theta', **{'tensor': [3.0]}
            )
            theta['lower'] = 0.0
            # improper 1/x prior on the population size
            joint_list.append(
                create_one_on_x_prior(
                    f'{coalescent_id}.theta.prior', f'{coalescent_id}.theta'
                )
            )
            if arg.coalescent == 'exponential':
                growth = Parameter.json_factory(
                    f'{coalescent_id}.growth', **{'tensor': [1.0]}
                )
                params['growth'] = growth
                joint_list.append(
                    Distribution.json_factory(
                        f'{coalescent_id}.growth.prior',
                        'torch.distributions.Laplace',
                        f'{coalescent_id}.growth',
                        {
                            'loc': 0.0,
                            'scale': 1.0,
                        },
                    )
                )
        elif arg.coalescent == 'skygrid':
            # population sizes estimated in log space, one per grid point
            theta_log = Parameter.json_factory(
                f'{coalescent_id}.theta.log', **{'tensor': 3.0, 'full': [arg.grid]}
            )
            theta = {
                'id': f'{coalescent_id}.theta',
                'type': 'TransformedParameter',
                'transform': 'torch.distributions.ExpTransform',
                'x': theta_log,
            }
        elif arg.coalescent == 'skyride':
            # one log population size per coalescent interval (taxa - 1)
            theta_log = Parameter.json_factory(
                f'{coalescent_id}.theta.log',
                **{'tensor': 3.0, 'full': [len(taxa['taxa']) - 1]},
            )
            theta = {
                'id': f'{coalescent_id}.theta',
                'type': 'TransformedParameter',
                'transform': 'torch.distributions.ExpTransform',
                'x': theta_log,
            }
        if arg.coalescent in ('skygrid', 'skyride'):
            # GMRF smoothing prior over the log population size trajectory
            gmrf = {
                'id': 'gmrf',
                'type': 'GMRF',
                'x': f'{coalescent_id}.theta.log',
                'precision': Parameter.json_factory(
                    'gmrf.precision',
                    **{'tensor': [0.1]},
                ),
            }
            if arg.time_aware:
                gmrf['tree_model'] = 'tree'
            gmrf['precision']['lower'] = 0.0
            joint_list.append(gmrf)
            joint_list.append(
                Distribution.json_factory(
                    'gmrf.precision.prior',
                    'torch.distributions.Gamma',
                    'gmrf.precision',
                    {
                        'concentration': 0.0010,
                        'rate': 0.0010,
                    },
                )
            )
        prior = create_coalesent(coalescent_id, 'tree', theta, arg, **params)
    elif arg.birth_death is not None:
        birth_death_id = 'bdsk'
        # NOTE(review): these priors use a 'mean' key while
        # torch.distributions.LogNormal is parameterized by loc/scale —
        # confirm the Distribution factory accepts or maps 'mean'.
        joint_list.append(
            Distribution.json_factory(
                f'{birth_death_id}.R.prior',
                'torch.distributions.LogNormal',
                f'{birth_death_id}.R',
                {
                    'mean': 1.0,
                    'scale': 1.25,
                },
            ),
        )
        joint_list.append(
            Distribution.json_factory(
                f'{birth_death_id}.delta.prior',
                'torch.distributions.LogNormal',
                f'{birth_death_id}.delta',
                {
                    'mean': 1.0,
                    'scale': 1.25,
                },
            ),
        )
        joint_list.append(
            Distribution.json_factory(
                f'{birth_death_id}.origin.prior',
                'torch.distributions.LogNormal',
                f'{birth_death_id}.origin.unshifted',
                {
                    'mean': 1.0,
                    'scale': 1.25,
                },
            ),
        )
        joint_list.append(
            Distribution.json_factory(
                f'{birth_death_id}.rho.prior',
                'torch.distributions.Beta',
                f'{birth_death_id}.rho',
                {
                    'concentration1': 1.0,
                    'concentration0': 9999.0,
                },
            ),
        )
        prior = create_birth_death(birth_death_id, 'tree', arg)
    return [prior] + joint_list
def create_substitution_model(id_, model, arg):
    """Build the JSON specification of a substitution model.

    Supports JC69, HKY, GTR (nucleotide), MG94 (codon), and LG/WAG
    (amino acid) models.

    :param id_: identifier of the substitution model object
    :param model: model name
    :param arg: parsed command-line arguments
    :return: dictionary describing the substitution model
    :raises ValueError: if user-supplied frequencies have the wrong dimension
    """
    if model == 'JC69':
        return {'id': id_, 'type': 'JC69'}
    elif model == 'HKY' or model == 'GTR':
        frequencies = Parameter.json_factory(
            f'{id_}.frequencies', **{'tensor': [0.25] * 4}
        )
        frequencies['simplex'] = True
        alignment = None
        if arg.frequencies is not None:
            if arg.frequencies == 'empirical':
                # estimate base frequencies from the alignment
                alignment = build_alignment(arg.input, NucleotideDataType(''))
                frequencies['tensor'] = calculate_frequencies(alignment)
            elif arg.frequencies != 'equal':
                # user-supplied comma-separated frequencies
                frequencies['tensor'] = list(map(float, arg.frequencies.split(',')))
                if len(frequencies['tensor']) != 4:
                    raise ValueError(
                        f'The dimension of the frequencies parameter '
                        f'({len(frequencies["tensor"])}) does not match the data type '
                        f'state count 4'
                    )
        if model == 'HKY':
            kappa = Parameter.json_factory(f'{id_}.kappa', **{'tensor': [3.0]})
            kappa['lower'] = 0.0
            if alignment is not None:
                # initialize kappa from the empirical alignment as well
                kappa['tensor'] = [calculate_kappa(alignment, frequencies['tensor'])]
            return {
                'id': id_,
                'type': 'HKY',
                'kappa': kappa,
                'frequencies': frequencies,
            }
        else:
            rates = Parameter.json_factory(
                f'{id_}.rates', **{'tensor': 1 / 6, 'full': [6]}
            )
            rates['simplex'] = True
            # symmetric index mapping of the 6 GTR exchangeabilities
            # (diagonal entries use the sentinel 6)
            mapping = ((6, 0, 1, 2), (0, 6, 3, 4), (1, 3, 6, 5), (2, 4, 5, 6))
            if alignment is not None:
                rel_rates = np.array(calculate_substitutions(alignment, mapping))
                # normalize the first 6 counts into a simplex
                rates['tensor'] = (rel_rates[:-1] / rel_rates[:-1].sum()).tolist()
            return {
                'id': id_,
                'type': 'GTR',
                'rates': rates,
                'frequencies': frequencies,
            }
    elif model == 'MG94':
        alpha = Parameter.json_factory(f'{id_}.alpha', **{'tensor': [1.0]})
        alpha['lower'] = 0.0
        beta = Parameter.json_factory(f'{id_}.beta', **{'tensor': [1.0]})
        beta['lower'] = 0.0
        kappa = Parameter.json_factory(f'{id_}.kappa', **{'tensor': [1.0]})
        kappa['lower'] = 0.0
        data_type_json = 'data_type'
        if not hasattr(arg, '_data_type'):
            # first use: embed the full data type definition; later calls
            # reference it by id ('data_type')
            arg._data_type = create_data_type('data_type', arg)
            data_type_json = arg._data_type
        data_type = process_object(arg._data_type, {})
        frequencies = Parameter.json_factory(
            f'{id_}.frequencies',
            **{'tensor': 1 / data_type.state_count, 'full': [data_type.state_count]},
        )
        # it is a simplex but it is fixed
        frequencies['lower'] = frequencies['upper'] = 1
        if arg.frequencies is not None and arg.frequencies != 'equal':
            if arg.frequencies == 'F3x4':
                # codon frequencies from positional nucleotide frequencies
                alignment = build_alignment(arg.input, data_type)
                frequencies['tensor'] = calculate_F3x4(alignment)
            else:
                frequencies['tensor'] = list(map(float, arg.frequencies.split(',')))
                if len(frequencies['tensor']) != data_type.state_count:
                    raise ValueError(
                        f'The dimension of the frequencies parameter '
                        f'({len(frequencies["tensor"])}) does not match the data type '
                        f'state count {data_type.state_count}'
                    )
            # tensor is now fully specified, drop the broadcast hint
            del frequencies['full']
        return {
            'id': id_,
            'type': 'MG94',
            'alpha': alpha,
            'beta': beta,
            'kappa': kappa,
            'frequencies': frequencies,
            'data_type': data_type_json,
        }
    elif model in ('LG', 'WAG'):
        return {'id': id_, 'type': model}
def create_tree_model(id_: str, taxa: dict, arg):
    """Build the JSON specification of a tree model.

    Detects the input tree format (nexus vs newick), then creates a time tree
    (ratio or shift node-height parameterization) when a clock is requested,
    otherwise an unrooted tree with free branch lengths.

    :param id_: identifier of the tree model object
    :param taxa: dictionary with a 'taxa' list carrying 'date' attributes
    :param arg: parsed command-line arguments
    :return: dictionary describing the tree model
    """
    tree_format = 'newick'
    # sniff the first line to detect a NEXUS file
    with open(arg.tree) as fp:
        if next(fp).upper().startswith('#NEXUS'):
            tree_format = 'nexus'
    if tree_format == 'nexus':
        tree = Tree.get(
            path=arg.tree,
            schema=tree_format,
            tree_offset=0,
            preserve_underscores=True,
        )
        newick = str(tree) + ';'
    else:
        with open(arg.tree) as fp:
            newick = fp.read()
            newick = newick.strip()
    kwargs = {}
    if arg.keep:
        kwargs['keep_branch_lengths'] = True
    if arg.clock is not None:
        dates = [taxon['attributes']['date'] for taxon in taxa['taxa']]
        # span between the youngest and oldest sampling dates
        offset = max(dates) - min(dates)
        if arg.heights == 'ratio':
            # internal node heights parameterized as ratios in (0, 1)
            ratios = Parameter.json_factory(
                f'{id_}.ratios', **{'tensor': 0.1, 'full': [len(dates) - 2]}
            )
            ratios['lower'] = 0.0
            ratios['upper'] = 1.0
            root_height = Parameter.json_factory(
                f'{id_}.root_height', **{'tensor': [offset + 1.0]}
            )
            if 'root_height_init' in arg:
                # presumably root_height_init is given as a date, converted to
                # a height relative to the youngest sample — TODO confirm
                root_height['tensor'] = [max(dates) - arg.root_height_init]
            elif arg.coalescent == 'skygrid':
                # NOTE(review): other tensors are lists; confirm a bare scalar
                # (arg.cutoff) is accepted here, or whether [arg.cutoff] was meant
                root_height['tensor'] = arg.cutoff
            # the root cannot be younger than the oldest sample
            root_height['lower'] = offset
            tree_model = ReparameterizedTimeTreeModel.json_factory(
                id_, newick, ratios, root_height, 'taxa', **kwargs
            )
        elif arg.heights == 'shift':
            # node heights parameterized as positive height differences
            shifts = Parameter.json_factory(
                f'{id_}.shifts', **{'tensor': 0.1, 'full': [len(dates) - 1]}
            )
            shifts['lower'] = 0.0
            node_heights = {
                'id': f'{id_}.heights',
                'type': 'TransformedParameter',
                'transform': 'torchtree.evolution.tree_height_transform'
                '.DifferenceNodeHeightTransform',
                'x': shifts,
                'parameters': {'tree_model': id_},
            }
            tree_model = FlexibleTimeTreeModel.json_factory(
                id_, newick, node_heights, 'taxa', **kwargs
            )
    else:
        # unrooted tree: 2N - 3 free branch lengths
        branch_lengths = Parameter.json_factory(
            f'{id_}.blens', **{'tensor': 0.1, 'full': [len(taxa['taxa']) * 2 - 3]}
        )
        branch_lengths['lower'] = 0.0
        tree_model = UnRootedTreeModel.json_factory(
            id_, newick, branch_lengths, 'taxa', **kwargs
        )
    return tree_model