Example #1
import torch

# Assumed imports; torchtree's actual module layout may differ.
from torchtree import Parameter, TransformedParameter


def test_transformed_parameter():
    t = torch.tensor([1.0, 2.0])
    p1 = Parameter('param', t)
    transform = torch.distributions.ExpTransform()
    p2 = TransformedParameter('transformed', p1, transform)
    assert p2.need_update is False
    assert torch.all(p2.tensor.eq(t.exp()))
    assert p2.need_update is False

    p1.tensor = torch.tensor([1.0, 3.0])
    assert p2.need_update is True

    assert torch.all(p2.tensor.eq(p1.tensor.exp()))
    assert p2.need_update is False

    # Calling p2 returns the log abs det Jacobian of the transform;
    # for ExpTransform, log|det J| = x, so p2() equals p1's tensor.
    p1.tensor = t
    assert p2.need_update is True
    assert torch.all(p2().eq(p1.tensor))
    assert p2.need_update is False

    assert torch.all(p2.tensor.eq(p1.tensor.exp()))
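The test above exercises a lazy-update contract: `TransformedParameter` caches the transformed tensor, recomputes it only after the underlying `Parameter` changes, and calling the object yields the log abs det Jacobian. A minimal sketch of that pattern, using hypothetical `LazyParameter` and `LazyTransformedParameter` classes rather than torchtree's real implementation:

import torch


class LazyParameter:
    # Hypothetical stand-in for torchtree's Parameter: holds a tensor and
    # notifies registered listeners when the tensor is reassigned.
    def __init__(self, name, tensor):
        self.name = name
        self._tensor = tensor
        self.listeners = []

    @property
    def tensor(self):
        return self._tensor

    @tensor.setter
    def tensor(self, tensor):
        self._tensor = tensor
        for listener in self.listeners:
            listener.need_update = True


class LazyTransformedParameter:
    # Hypothetical stand-in for torchtree's TransformedParameter: caches
    # transform(x.tensor) and recomputes it only when flagged stale.
    def __init__(self, name, x, transform):
        self.name = name
        self.x = x
        self.transform = transform
        self._tensor = transform(x.tensor)  # computed eagerly at construction
        self.need_update = False
        x.listeners.append(self)

    def _update(self):
        if self.need_update:
            self._tensor = self.transform(self.x.tensor)
            self.need_update = False

    @property
    def tensor(self):
        self._update()
        return self._tensor

    def __call__(self):
        # log|det J| of the transform at the current value of x.
        self._update()
        return self.transform.log_abs_det_jacobian(self.x.tensor, self._tensor)

Registering the transformed parameter as a listener is what lets `p1.tensor = ...` flip `need_update` without eagerly recomputing anything.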
Example #2
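The benchmark functions in this example are written against a `benchmark` decorator whose wrapper takes `(replicates, *args)` and returns `(total_time, last_result)`; that is why each `fn` below is defined with one parameter but called with two. A minimal sketch of the assumed decorator:

import functools
import time


def benchmark(fn):
    # Wraps fn so the call site supplies the replicate count first;
    # returns (elapsed seconds, result of the last call).
    @functools.wraps(fn)
    def wrapper(replicates, *args):
        start = time.perf_counter()
        for _ in range(replicates):
            result = fn(*args)
        return time.perf_counter() - start, result

    return wrapper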
# Assumes the `benchmark` decorator sketched above, plus script-level helpers
# `read_tree`, `heights_from_branch_lengths`, `transform`, and `transform2`,
# and torchtree's `Parameter`, `Taxa`, `Taxon`, and
# `ReparameterizedTimeTreeModel` (with `torch` imported).
def ratio_transform(args):
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    # One value per internal node: taxa_count - 2 ratios plus the root height.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        return tree_model.transform(
            ratios_root_height,
        )

    @benchmark
    def fn_grad(ratios_root_height):
        heights = tree_model.transform(
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    total_time, heights = fn(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, heights = fn_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print('  JIT off')

    @benchmark
    def fn2(ratios_root_height):
        return transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad(ratios_root_height):
        heights = transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('  JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(ratios_root_height):
        return transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad_jit(ratios_root_height):
        heights = transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT off')

    @benchmark
    def fn3(ratios_root_height):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad(ratios_root_height):
        heights = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(ratios_root_height):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad_jit(ratios_root_height):
        heights = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')
Example #3
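This example benchmarks `log_abs_det_jacobian` of the node-height transform. For small trees such a method can be sanity-checked against autograd; the helper below is hypothetical and not part of the source:

import torch


def check_log_abs_det_jacobian(transform_fn, log_abs_det_jacobian, x):
    # Build the full Jacobian with autograd (O(n^2); small inputs only) and
    # compare its log|det| against the analytic implementation.
    jac = torch.autograd.functional.jacobian(transform_fn, x)
    expected = torch.linalg.slogdet(jac).logabsdet
    actual = log_abs_det_jacobian(x, transform_fn(x))
    assert torch.allclose(actual, expected, atol=1e-6), (actual, expected)

For instance, `check_log_abs_det_jacobian(tree_model.transform, tree_model.transform.log_abs_det_jacobian, ratios_root_height.tensor)`, assuming the transform is callable on a plain tensor as in the benchmarks.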
# Assumes the same context as Example #2: the `benchmark` decorator,
# `read_tree`, `heights_from_branch_lengths`, and the torchtree classes.
def ratio_transform_jacobian(args):
    tree = read_tree(args.tree, True, True)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    taxa_count = len(taxa)
    # One value per internal node: taxa_count - 2 ratios plus the root height.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print('  JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f'  {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
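All the benchmark functions read `tree`, `replicates`, `output`, and `debug` off an `args` namespace. A hypothetical argparse driver showing the assumed interface (flag names are illustrative):

import argparse


def main():
    parser = argparse.ArgumentParser(description='ratio-transform benchmarks')
    parser.add_argument('--tree', required=True, help='tree file')
    parser.add_argument('--replicates', type=int, default=100,
                        help='number of evaluations per benchmark')
    parser.add_argument('--output', type=argparse.FileType('w'), default=None,
                        help='optional CSV output')
    parser.add_argument('--debug', action='store_true',
                        help='print the gradient of log|det J|')
    args = parser.parse_args()

    ratio_transform(args)
    ratio_transform_jacobian(args)

    if args.output:
        args.output.close()


if __name__ == '__main__':
    main()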