Example #1
0
def create_branch_model(id_, tree_id, taxa_count, arg):
    """Build the JSON description of the branch (clock) model.

    The global rate parameter ``{id_}.rate`` is initialised from ``arg.rate``
    when it is set, otherwise from ``arg.rate_init`` when that attribute is
    present on ``arg``, otherwise from 0.001.  When ``arg.rate`` is set, the
    parameter's lower and upper bounds are both set to that value.

    Args:
        id_: identifier of the branch model object.
        tree_id: identifier of the tree model the clock refers to.
        taxa_count: number of taxa; only used by the relaxed clocks to size
            their per-branch vectors (``2 * taxa_count - 2`` entries).
        arg: parsed options; reads ``rate``, optionally ``rate_init``, and
            ``clock``.

    Returns:
        A clock-model dict for ``arg.clock`` in ``'strict'``,
        ``'horseshoe'`` or ``'ucln'``; ``None`` for any other value.
    """
    fixed_rate = arg.rate
    if fixed_rate is not None:
        initial_rate = fixed_rate
    elif 'rate_init' in arg:
        initial_rate = arg.rate_init
    else:
        initial_rate = 0.001

    rate_parameter = Parameter.json_factory(
        f'{id_}.rate', **{'tensor': [initial_rate]}
    )
    rate_parameter['lower'] = 0.0
    if fixed_rate is not None:
        # lower == upper presumably pins the rate to the user value — confirm
        # against Parameter's handling of equal bounds
        rate_parameter['lower'] = fixed_rate
        rate_parameter['upper'] = fixed_rate

    clock = arg.clock
    if clock == 'strict':
        return {
            'id': id_,
            'type': 'StrictClockModel',
            'tree_model': tree_id,
            'rate': rate_parameter,
        }

    if clock == 'horseshoe':
        unscaled = Parameter.json_factory(
            f'{id_}.rates.unscaled',
            **{'tensor': 1.0, 'full': [2 * taxa_count - 2]},
        )
        unscaled['lower'] = 0.0
        # per-branch rates are the unscaled vector rescaled with the global
        # rate via RescaledRateTransform
        return {
            'id': f'{id_}.simple',
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': {
                'id': f'{id_}.rates',
                'type': 'TransformedParameter',
                'transform': 'RescaledRateTransform',
                'x': unscaled,
                'parameters': {
                    'tree_model': tree_id,
                    'rate': rate_parameter,
                },
            },
        }

    if clock == 'ucln':
        branch_rates = Parameter.json_factory(
            f'{id_}.rates',
            **{'tensor': 0.001, 'full': [2 * taxa_count - 2]},
        )
        branch_rates['lower'] = 0.0
        return {
            'id': id_,
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': branch_rates,
        }

    return None
Example #2
0
def create_birth_death(birth_death_id, tree_id, arg):
    """Build the JSON description of a BDSK birth-death tree prior.

    ``R``, ``delta`` and ``s`` are vectors of ``arg.grid`` entries bounded
    below by 0.  ``s`` has both bounds at 0.  ``rho`` is a single value in
    [0, 1].  The origin is an ``AffineTransform`` of a non-negative
    ``origin.unshifted`` parameter, using the root height of ``tree_id`` as
    ``loc`` and 1.0 as ``scale``.

    Returns:
        The ``BDSKModel`` dict.
    """
    def grid_vector(name, value):
        # a parameter with arg.grid entries, all initialised to ``value``
        return Parameter.json_factory(
            f'{birth_death_id}.{name}',
            **{'tensor': value, 'full': [arg.grid]},
        )

    r_param = grid_vector('R', 3.0)
    r_param['lower'] = 0.0

    delta_param = grid_vector('delta', 3.0)
    delta_param['lower'] = 0.0

    s_param = grid_vector('s', 0.0)
    s_param['lower'] = 0.0
    s_param['upper'] = 0.0

    rho_param = Parameter.json_factory(
        f'{birth_death_id}.rho', **{'tensor': [1.0e-6]}
    )
    rho_param['lower'] = 0.0
    rho_param['upper'] = 1.0

    unshifted_origin = {
        'id': f'{birth_death_id}.origin.unshifted',
        'type': 'Parameter',
        'tensor': [1.0],
        'lower': 0.0,
    }
    origin = {
        'id': f'{birth_death_id}.origin',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.AffineTransform',
        'x': unshifted_origin,
        'parameters': {
            'loc': f'{tree_id}.root_height',
            'scale': 1.0,
        },
    }

    return {
        'id': birth_death_id,
        'type': 'BDSKModel',
        'tree_model': tree_id,
        'R': r_param,
        'delta': delta_param,
        's': s_param,
        'rho': rho_param,
        'origin': origin,
    }
Example #3
0
def create_clock_horseshoe_prior(branch_model_id, tree_id):
    """Build the priors of a horseshoe relaxed clock.

    Returns a list containing, in this order:

    * a ``ScaleMixtureNormal`` prior (location 0.0) on the
      ``LogDifferenceRateTransform`` of ``{branch_model_id}.rates.unscaled``,
      scaled by a global scale and per-element local scales;
    * a ``CTMCScale`` prior on ``{branch_model_id}.rate``;
    * ``Cauchy(0, 1)`` priors on the global scale and on the local scales,
      both of which are bounded below by 0.
    """
    unscaled_id = f'{branch_model_id}.rates.unscaled'
    log_differences = {
        'id': f'{branch_model_id}.rates.logdiff',
        'type': 'TransformedParameter',
        'transform': 'LogDifferenceRateTransform',
        'x': unscaled_id,
        'parameters': {'tree_model': tree_id},
    }

    global_scale = Parameter.json_factory(
        f'{branch_model_id}.global.scale', **{'tensor': [1.0]}
    )
    local_scales = Parameter.json_factory(
        f'{branch_model_id}.local.scales',
        **{'tensor': 1.0, 'full_like': unscaled_id},
    )
    for scale_param in (global_scale, local_scales):
        scale_param['lower'] = 0.0

    mixture_prior = ScaleMixtureNormal.json_factory(
        f'{branch_model_id}.scale.mixture.prior',
        log_differences,
        0.0,
        global_scale,
        local_scales,
    )
    rate_prior = CTMCScale.json_factory(
        f'{branch_model_id}.rate.prior', f'{branch_model_id}.rate', 'tree'
    )
    cauchy_priors = [
        Distribution.json_factory(
            f'{branch_model_id}.{name}.prior',
            'torch.distributions.Cauchy',
            f'{branch_model_id}.{name}',
            {'loc': 0.0, 'scale': 1.0},
        )
        for name in ('global.scale', 'local.scales')
    ]
    return [mixture_prior, rate_prior, *cauchy_priors]
Example #4
0
def create_gamma_distribution(var_id, x_unres, json_object, concentration,
                              rate):
    """Build a Gamma distribution over ``x_unres`` with learnable parameters.

    Both Gamma parameters are ``ExpTransform``-ed unconstrained parameters
    (ids ending in ``.concentration.unres`` / ``.rate.unres``).  So
    ``concentration`` and ``rate`` are the initial values on the
    *unconstrained* (log) scale.  A scalar initial value fills a tensor
    shaped like ``json_object['id']`` (via ``full_like``).  A list initial
    value is used verbatim and ``full_like`` is dropped, mirroring
    ``create_normal_distribution``.

    Args:
        var_id: prefix for the ids of the created objects.
        x_unres: the variable the distribution is defined over.
        json_object: JSON of the original parameter; only its ``'id'`` is read.
        concentration: initial unconstrained concentration (scalar or list).
        rate: initial unconstrained rate (scalar or list).

    Returns:
        ``(distribution, concentration_param, rate_param)``.
    """
    prefix = var_id + '.' + json_object['id']

    def exp_transformed(name, init):
        # Positive parameter expressed as exp() of an unconstrained one.
        unconstrained = Parameter.json_factory(
            prefix + '.' + name + '.unres',
            **{
                'full_like': json_object['id'],
                'tensor': init
            },
        )
        if isinstance(init, list):
            # explicit values already define the shape; a list cannot be
            # used as a fill value for full_like
            del unconstrained['full_like']
        return {
            'id': prefix + '.' + name,
            'type': 'TransformedParameter',
            'transform': 'torch.distributions.ExpTransform',
            'x': unconstrained,
        }

    concentration_param = exp_transformed('concentration', concentration)
    rate_param = exp_transformed('rate', rate)
    distr = Distribution.json_factory(
        prefix,
        'torch.distributions.Gamma',
        x_unres,
        {
            'concentration': concentration_param,
            'rate': rate_param
        },
    )
    return distr, concentration_param, rate_param
Example #5
0
def create_normal_distribution(var_id, x_unres, json_object, loc, scale):
    """Build a Normal distribution over ``x_unres`` with learnable parameters.

    The location parameter is used directly.  The scale is an
    ``ExpTransform`` of an unconstrained ``.scale.unres`` parameter, so
    ``scale`` is its initial value on the log scale.  Scalar initial values
    fill tensors shaped like ``x_unres`` (``full_like``).  List initial
    values are used verbatim and ``full_like`` is removed.

    Returns:
        ``(distribution, loc_param, scale_param)``.
    """
    prefix = var_id + '.' + json_object['id']

    loc_param = Parameter.json_factory(
        prefix + '.loc',
        **{'full_like': x_unres, 'tensor': loc},
    )
    if isinstance(loc, list):
        # presumably because an explicit list already determines the shape
        del loc_param['full_like']

    log_scale = Parameter.json_factory(
        prefix + '.scale.unres',
        **{'full_like': x_unres, 'tensor': scale},
    )
    scale_param = {
        'id': prefix + '.scale',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.ExpTransform',
        'x': log_scale,
    }
    if isinstance(scale, list):
        del log_scale['full_like']

    distr = Distribution.json_factory(
        prefix,
        'torch.distributions.Normal',
        x_unres,
        {'loc': loc_param, 'scale': scale_param},
    )
    return distr, loc_param, scale_param
Example #6
0
def create_site_model_srd06_mus(id_):
    """Build the SRD06 relative-rate parameter.

    A two-entry simplex ``srd06.mu`` (initialised to [0.5, 0.5]) is passed
    through a ``ConvexCombinationTransform`` with weights 2/3 and 1/3.  The
    weights presumably correspond to codon positions 1+2 versus 3 — TODO
    confirm.

    Args:
        id_: identifier of the resulting transformed parameter.

    Returns:
        The ``TransformedParameter`` dict.
    """
    simplex = Parameter.json_factory('srd06.mu', **{'tensor': [0.5, 0.5]})
    simplex['simplex'] = True
    return {
        'id': id_,
        'type': 'TransformedParameter',
        'transform': 'ConvexCombinationTransform',
        'x': simplex,
        'parameters': {'weights': [2 / 3, 1 / 3]},
    }
Example #7
0
def create_ucln_prior(branch_model_id):
    """Build the priors of an uncorrelated log-normal (UCLN) relaxed clock.

    Returns a list containing, in this order:

    * a ``LogNormal`` prior on ``{branch_model_id}.rates`` with learnable
      ``mean`` (init 0.001) and ``scale`` (init 1.0), both bounded below by 0;
    * a ``CTMCScale`` prior on that mean;
    * a ``Gamma(0.5396, 2.6184)`` prior on that scale.
    """
    prefix = f'{branch_model_id}.rates.prior'
    mean = Parameter.json_factory(f'{prefix}.mean', **{'tensor': [0.001]})
    scale = Parameter.json_factory(f'{prefix}.scale', **{'tensor': [1.0]})
    for positive in (mean, scale):
        positive['lower'] = 0.0

    return [
        Distribution.json_factory(
            prefix,
            'LogNormal',
            f'{branch_model_id}.rates',
            {'mean': mean, 'scale': scale},
        ),
        CTMCScale.json_factory(
            f'{branch_model_id}.mean.prior',
            f'{prefix}.mean',
            'tree',
        ),
        Distribution.json_factory(
            f'{branch_model_id}.rates.scale.prior',
            'torch.distributions.Gamma',
            f'{prefix}.scale',
            {'concentration': 0.5396, 'rate': 2.6184},
        ),
    ]
Example #8
0
def create_site_model(id_, arg, w=None):
    """Build the JSON description of the site model.

    With ``arg.categories == 1`` a ``ConstantSiteModel`` is returned.
    Otherwise a ``WeibullSiteModel`` is returned, with ``arg.categories``
    categories and a shape parameter initialised to 0.1 and bounded below
    by 0.  When ``arg.model`` is ``'SRD06'``, ``w`` is stored under the
    ``'mu'`` key.
    """
    if arg.categories != 1:
        shape = Parameter.json_factory(f'{id_}.shape', **{'tensor': [0.1]})
        shape['lower'] = 0.0
        model_json = {
            'id': id_,
            'type': 'WeibullSiteModel',
            'categories': arg.categories,
            'shape': shape,
        }
    else:
        model_json = {'id': id_, 'type': 'ConstantSiteModel'}

    if arg.model == 'SRD06':
        model_json['mu'] = w
    return model_json
Example #9
0
def _create_coalescent_priors(taxa, arg):
    """Return ``(coalescent_prior, hyper_priors)`` for ``arg.coalescent``.

    Supported values are ``'constant'``, ``'exponential'``, ``'skygrid'`` and
    ``'skyride'``.

    Raises:
        ValueError: for any other ``arg.coalescent`` value.
    """
    coalescent_id = 'coalescent'
    params = {}
    joint_list = []

    if arg.coalescent in ('constant', 'exponential'):
        theta = Parameter.json_factory(
            f'{coalescent_id}.theta', **{'tensor': [3.0]}
        )
        theta['lower'] = 0.0
        joint_list.append(
            create_one_on_x_prior(
                f'{coalescent_id}.theta.prior', f'{coalescent_id}.theta'
            )
        )
        if arg.coalescent == 'exponential':
            params['growth'] = Parameter.json_factory(
                f'{coalescent_id}.growth', **{'tensor': [1.0]}
            )
            joint_list.append(
                Distribution.json_factory(
                    f'{coalescent_id}.growth.prior',
                    'torch.distributions.Laplace',
                    f'{coalescent_id}.growth',
                    {
                        'loc': 0.0,
                        'scale': 1.0,
                    },
                )
            )
    elif arg.coalescent in ('skygrid', 'skyride'):
        # theta = exp(theta.log); skygrid uses arg.grid values, skyride one
        # fewer than the number of taxa
        if arg.coalescent == 'skygrid':
            size = arg.grid
        else:
            size = len(taxa['taxa']) - 1
        theta_log = Parameter.json_factory(
            f'{coalescent_id}.theta.log', **{'tensor': 3.0, 'full': [size]}
        )
        theta = {
            'id': f'{coalescent_id}.theta',
            'type': 'TransformedParameter',
            'transform': 'torch.distributions.ExpTransform',
            'x': theta_log,
        }

        # GMRF smoothing prior on theta.log with a Gamma hyper-prior on its
        # precision
        gmrf = {
            'id': 'gmrf',
            'type': 'GMRF',
            'x': f'{coalescent_id}.theta.log',
            'precision': Parameter.json_factory(
                'gmrf.precision',
                **{'tensor': [0.1]},
            ),
        }
        if arg.time_aware:
            gmrf['tree_model'] = 'tree'
        gmrf['precision']['lower'] = 0.0
        joint_list.append(gmrf)
        joint_list.append(
            Distribution.json_factory(
                'gmrf.precision.prior',
                'torch.distributions.Gamma',
                'gmrf.precision',
                {
                    'concentration': 0.0010,
                    'rate': 0.0010,
                },
            )
        )
    else:
        raise ValueError(f'Unsupported coalescent model: {arg.coalescent!r}')

    prior = create_coalesent(coalescent_id, 'tree', theta, arg, **params)
    return prior, joint_list


def _create_birth_death_priors(arg):
    """Return ``(bdsk_prior, hyper_priors)`` for the BDSK birth-death model."""
    birth_death_id = 'bdsk'
    # NOTE(review): torch.distributions.LogNormal takes ``loc``/``scale``,
    # not ``mean``; these three priors may fail when instantiated (the UCLN
    # prior uses the bare 'LogNormal' type with 'mean') — confirm intent.
    log_normal_targets = (
        ('R', 'R'),
        ('delta', 'delta'),
        ('origin', 'origin.unshifted'),
    )
    joint_list = [
        Distribution.json_factory(
            f'{birth_death_id}.{name}.prior',
            'torch.distributions.LogNormal',
            f'{birth_death_id}.{target}',
            {
                'mean': 1.0,
                'scale': 1.25,
            },
        )
        for name, target in log_normal_targets
    ]
    joint_list.append(
        Distribution.json_factory(
            f'{birth_death_id}.rho.prior',
            'torch.distributions.Beta',
            f'{birth_death_id}.rho',
            {
                'concentration1': 1.0,
                'concentration0': 9999.0,
            },
        )
    )
    prior = create_birth_death(birth_death_id, 'tree', arg)
    return prior, joint_list


def create_time_tree_prior(taxa, arg):
    """Build the tree prior and the priors on its parameters.

    A coalescent prior is used when ``arg.coalescent`` is set; otherwise a
    BDSK birth-death prior is used when ``arg.birth_death`` is set.

    Args:
        taxa: taxa JSON; ``len(taxa['taxa'])`` sizes the skyride parameter.
        arg: parsed command-line options.

    Returns:
        ``[tree_prior, *hyper_priors]``, or an empty list when neither
        ``arg.coalescent`` nor ``arg.birth_death`` is set.

    Raises:
        ValueError: for an unsupported ``arg.coalescent`` value.
    """
    if arg.coalescent is not None:
        prior, joint_list = _create_coalescent_priors(taxa, arg)
    elif arg.birth_death is not None:
        prior, joint_list = _create_birth_death_priors(arg)
    else:
        # No tree prior requested: return nothing rather than ``[[]]``,
        # whose empty-list element is not a valid JSON object description.
        return []
    return [prior] + joint_list
Example #10
0
def _parse_frequencies(spec, state_count):
    """Parse comma-separated frequencies ``spec`` into a list of floats.

    Raises:
        ValueError: if the number of values differs from ``state_count``.
    """
    values = [float(token) for token in spec.split(',')]
    if len(values) != state_count:
        raise ValueError(
            f'The dimension of the frequencies parameter '
            f'({len(values)}) does not match the data type '
            f'state count {state_count}'
        )
    return values


def create_substitution_model(id_, model, arg):
    """Build the JSON description of a substitution model.

    Supported ``model`` values are ``'JC69'``, ``'HKY'``, ``'GTR'``,
    ``'MG94'``, ``'LG'`` and ``'WAG'``; anything else yields ``None``.

    HKY/GTR: frequencies start at 0.25 each.  ``arg.frequencies`` may be:

    * ``'empirical'`` — computed from the alignment built from ``arg.input``,
      which also seeds HKY's kappa or GTR's rates;
    * ``'equal'`` or ``None`` — keep the defaults;
    * a comma-separated list of 4 values.

    MG94: ``arg.frequencies`` may be ``'F3x4'`` or a comma-separated list
    with one value per data-type state.  The data-type JSON is cached on
    ``arg._data_type``: the first call embeds the full object, and later
    calls refer to it by the id ``'data_type'``.

    Raises:
        ValueError: if user-supplied frequencies have the wrong length.
    """
    if model == 'JC69':
        return {'id': id_, 'type': 'JC69'}
    if model in ('LG', 'WAG'):
        return {'id': id_, 'type': model}

    if model in ('HKY', 'GTR'):
        freqs = Parameter.json_factory(
            f'{id_}.frequencies', **{'tensor': [0.25] * 4}
        )
        freqs['simplex'] = True
        alignment = None
        choice = arg.frequencies
        if choice == 'empirical':
            alignment = build_alignment(arg.input, NucleotideDataType(''))
            freqs['tensor'] = calculate_frequencies(alignment)
        elif choice is not None and choice != 'equal':
            freqs['tensor'] = _parse_frequencies(choice, 4)

        if model == 'HKY':
            kappa = Parameter.json_factory(f'{id_}.kappa', **{'tensor': [3.0]})
            kappa['lower'] = 0.0
            if alignment is not None:
                kappa['tensor'] = [calculate_kappa(alignment, freqs['tensor'])]
            return {
                'id': id_,
                'type': 'HKY',
                'kappa': kappa,
                'frequencies': freqs,
            }

        rates = Parameter.json_factory(
            f'{id_}.rates', **{'tensor': 1 / 6, 'full': [6]}
        )
        rates['simplex'] = True
        if alignment is not None:
            # the mapping sends identical-state pairs (diagonal) to index 6;
            # the last count is presumably that bucket and is dropped before
            # normalising the six exchangeabilities
            mapping = ((6, 0, 1, 2), (0, 6, 3, 4), (1, 3, 6, 5), (2, 4, 5, 6))
            counts = np.array(calculate_substitutions(alignment, mapping))
            rates['tensor'] = (counts[:-1] / counts[:-1].sum()).tolist()
        return {
            'id': id_,
            'type': 'GTR',
            'rates': rates,
            'frequencies': freqs,
        }

    if model == 'MG94':
        def unit_positive(name):
            # parameter initialised to 1.0 and bounded below by 0
            param = Parameter.json_factory(f'{id_}.{name}', **{'tensor': [1.0]})
            param['lower'] = 0.0
            return param

        alpha = unit_positive('alpha')
        beta = unit_positive('beta')
        kappa = unit_positive('kappa')

        if hasattr(arg, '_data_type'):
            data_type_json = 'data_type'
        else:
            arg._data_type = create_data_type('data_type', arg)
            data_type_json = arg._data_type
        data_type = process_object(arg._data_type, {})
        state_count = data_type.state_count

        freqs = Parameter.json_factory(
            f'{id_}.frequencies',
            **{'tensor': 1 / state_count, 'full': [state_count]},
        )
        # it is a simplex but it is fixed
        freqs['lower'] = freqs['upper'] = 1

        choice = arg.frequencies
        if choice is not None and choice != 'equal':
            if choice == 'F3x4':
                alignment = build_alignment(arg.input, data_type)
                freqs['tensor'] = calculate_F3x4(alignment)
            else:
                freqs['tensor'] = _parse_frequencies(choice, state_count)
            del freqs['full']

        return {
            'id': id_,
            'type': 'MG94',
            'alpha': alpha,
            'beta': beta,
            'kappa': kappa,
            'frequencies': freqs,
            'data_type': data_type_json,
        }

    return None
Example #11
0
def create_tree_model(id_: str, taxa: dict, arg):
    """Build the JSON description of the tree model.

    The topology is read from ``arg.tree``.  A file whose first line starts
    with ``#NEXUS`` (case-insensitive) is parsed with ``Tree.get`` and
    converted to newick; any other file is read as a newick string.

    Without a clock (``arg.clock is None``), an unrooted tree model with
    ``2 * len(taxa['taxa']) - 3`` branch lengths is returned.  With a clock,
    node heights are parameterised according to ``arg.heights``:

    * ``'ratio'``: ratios in [0, 1] plus a root height.  The root height is
      bounded below by the span of the taxon dates.
    * ``'shift'``: non-negative shifts mapped to node heights by a
      ``DifferenceNodeHeightTransform``.

    Args:
        id_: identifier of the tree model.
        taxa: taxa JSON; each ``taxon['attributes']['date']`` is read when a
            clock is used.
        arg: parsed command-line options.

    Returns:
        The tree model JSON dict.

    Raises:
        ValueError: if a clock is used and ``arg.heights`` is neither
            ``'ratio'`` nor ``'shift'``.
    """
    with open(arg.tree) as fp:
        # readline() returns '' on an empty file instead of raising
        # StopIteration like next(fp) would
        header = fp.readline()
        is_nexus = header.upper().startswith('#NEXUS')
        if not is_nexus:
            newick = (header + fp.read()).strip()
    if is_nexus:
        tree = Tree.get(
            path=arg.tree,
            schema='nexus',
            tree_offset=0,
            preserve_underscores=True,
        )
        newick = str(tree) + ';'

    kwargs = {'keep_branch_lengths': True} if arg.keep else {}

    if arg.clock is None:
        branch_lengths = Parameter.json_factory(
            f'{id_}.blens', **{'tensor': 0.1, 'full': [len(taxa['taxa']) * 2 - 3]}
        )
        branch_lengths['lower'] = 0.0
        return UnRootedTreeModel.json_factory(
            id_, newick, branch_lengths, 'taxa', **kwargs
        )

    dates = [taxon['attributes']['date'] for taxon in taxa['taxa']]
    offset = max(dates) - min(dates)

    if arg.heights == 'ratio':
        ratios = Parameter.json_factory(
            f'{id_}.ratios', **{'tensor': 0.1, 'full': [len(dates) - 2]}
        )
        ratios['lower'] = 0.0
        ratios['upper'] = 1.0

        root_height = Parameter.json_factory(
            f'{id_}.root_height', **{'tensor': [offset + 1.0]}
        )
        if 'root_height_init' in arg:
            root_height['tensor'] = [max(dates) - arg.root_height_init]
        elif arg.coalescent == 'skygrid':
            # keep the 1-element list shape used by every other root height
            # initialisation
            cutoff = arg.cutoff
            root_height['tensor'] = cutoff if isinstance(cutoff, list) else [cutoff]

        root_height['lower'] = offset
        return ReparameterizedTimeTreeModel.json_factory(
            id_, newick, ratios, root_height, 'taxa', **kwargs
        )

    if arg.heights == 'shift':
        shifts = Parameter.json_factory(
            f'{id_}.shifts', **{'tensor': 0.1, 'full': [len(dates) - 1]}
        )
        shifts['lower'] = 0.0
        node_heights = {
            'id': f'{id_}.heights',
            'type': 'TransformedParameter',
            'transform': 'torchtree.evolution.tree_height_transform'
            '.DifferenceNodeHeightTransform',
            'x': shifts,
            'parameters': {'tree_model': id_},
        }
        return FlexibleTimeTreeModel.json_factory(
            id_, newick, node_heights, 'taxa', **kwargs
        )

    raise ValueError(f'Unsupported node height parameterization: {arg.heights!r}')