def test_transformed_parameter():
    """Check that TransformedParameter tracks its source lazily.

    The transformed parameter wraps ``exp`` of a source Parameter. It must
    report a pending update after the source changes, refresh when its value
    (or its Jacobian via ``__call__``) is read, and expose
    log|d exp(x)/dx| = x as its Jacobian.
    """

    def all_equal(lhs, rhs):
        # Element-wise equality reduced with ``all``.
        return torch.all(lhs.eq(rhs))

    initial = torch.tensor([1.0, 2.0])
    source = Parameter('param', initial)
    exp_transform = torch.distributions.ExpTransform()
    derived = TransformedParameter('transformed', source, exp_transform)

    # Freshly built: nothing pending, and reading the value keeps it clean.
    assert derived.need_update is False
    assert all_equal(derived.tensor, initial.exp())
    assert derived.need_update is False

    # Changing the source marks the derived parameter stale; reading its
    # value refreshes it.
    source.tensor = torch.tensor([1.0, 3.0])
    assert derived.need_update is True
    assert all_equal(derived.tensor, source.tensor.exp())
    assert derived.need_update is False

    # jacobian: calling the parameter also refreshes it, and for exp the
    # log-abs-det-Jacobian equals the input itself.
    source.tensor = initial
    assert derived.need_update is True
    assert all_equal(derived(), source.tensor)
    assert derived.need_update is False
    assert all_equal(derived.tensor, source.tensor.exp())
def _time_eval_and_grad(func, parameter, replicates):
    """Benchmark ``func`` with and without a backward pass, printing timings.

    First runs ``func(parameter.tensor)`` through the ``benchmark`` wrapper
    ``replicates`` times and prints the total time. Then it sets
    ``parameter.requires_grad = True`` and benchmarks a variant that also
    back-propagates a tensor of ones through the output and zeroes the
    input's gradient after each call.

    Args:
        func: callable taking the ratio/root-height tensor and returning the
            node-height tensor.
        parameter: Parameter whose ``tensor`` is passed to ``func``.
        replicates: replicate count forwarded to the ``benchmark`` wrapper.

    Returns:
        ``(total_time, grad_total_time)``: the first elements returned by the
        ``benchmark`` wrapper for the two runs.
    """

    def evaluate(ratios):
        return func(ratios)

    def evaluate_grad(ratios):
        heights = func(ratios)
        heights.backward(torch.ones_like(ratios))
        # Reset so gradients do not accumulate across replicates.
        ratios.grad.data.zero_()
        return heights

    total_time, _ = benchmark(evaluate)(replicates, parameter.tensor)
    print(f' {replicates} evaluations: {total_time}')

    parameter.requires_grad = True
    grad_total_time, _ = benchmark(evaluate_grad)(replicates, parameter.tensor)
    print(f' {replicates} gradient evaluations: {grad_total_time}')
    return total_time, grad_total_time


def ratio_transform(args):
    """Benchmark the ratio/root-height -> node-height transform.

    Times evaluation and gradient evaluation of, in order:

    * ``tree_model.transform`` (results also written to ``args.output``);
    * the standalone ``transform`` function, eagerly and TorchScript-compiled;
    * ``transform2``, eagerly and TorchScript-compiled.

    Args:
        args: command-line namespace; reads ``tree`` (tree file path),
            ``replicates`` and ``output`` (a writable file object, or falsy
            to skip writing).
    """
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = [
        Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()
    ]
    # n - 2 ratios followed by the root height. The placeholder values are
    # replaced just below with the values matching the tree's branch lengths.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    # Transform as exposed by the tree model; only these timings are saved.
    total_time, grad_total_time = _time_eval_and_grad(
        tree_model.transform, ratios_root_height, replicates
    )
    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    forward_indices = tree_model.transform._forward_indices
    bounds = tree_model.transform._bounds
    # transform2 takes the indices as a Python list. Convert once here so the
    # tensor -> list conversion is not included in the timed region.
    forward_indices_list = forward_indices.tolist()

    print(' JIT off')
    ratios_root_height.requires_grad = False
    _time_eval_and_grad(
        lambda ratios: transform(forward_indices, bounds, ratios),
        ratios_root_height,
        replicates,
    )

    print(' JIT on')
    transform_script = torch.jit.script(transform)
    ratios_root_height.requires_grad = False
    _time_eval_and_grad(
        lambda ratios: transform_script(forward_indices, bounds, ratios),
        ratios_root_height,
        replicates,
    )

    print('ratio_transform v2 JIT off')
    ratios_root_height.requires_grad = False
    _time_eval_and_grad(
        lambda ratios: transform2(forward_indices_list, bounds, ratios),
        ratios_root_height,
        replicates,
    )

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)
    ratios_root_height.requires_grad = False
    _time_eval_and_grad(
        lambda ratios: transform2_script(forward_indices_list, bounds, ratios),
        ratios_root_height,
        replicates,
    )
def ratio_transform_jacobian(args):
    """Benchmark the log-abs-det-Jacobian of the ratio -> height transform.

    Times ``tree_model.transform.log_abs_det_jacobian`` (preceded by the
    forward transform) with and without a backward pass. Prints the timings
    and values, and optionally writes them to ``args.output``. With
    ``args.debug`` set, also prints the gradient of the Jacobian with respect
    to the ratios/root height.

    Args:
        args: command-line namespace; reads ``tree``, ``replicates``,
            ``output`` (writable file object or falsy) and ``debug``.
    """
    tree = read_tree(args.tree, True, True)
    taxa = [
        Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()
    ]
    taxa_count = len(taxa)
    # n - 2 ratios plus the root height (n - 1 internal nodes, assuming a
    # fully bifurcating rooted tree), consistent with ratio_transform. The
    # placeholder values are overwritten just below.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios):
        internal_heights = tree_model.transform(ratios)
        return tree_model.transform.log_abs_det_jacobian(
            ratios, internal_heights
        )

    @benchmark
    def fn_grad(ratios):
        internal_heights = tree_model.transform(ratios)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios, internal_heights
        )
        log_det_jac.backward()
        # Reset so gradients do not accumulate across replicates.
        ratios.grad.data.zero_()
        return log_det_jac

    print(' JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f' {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f' {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        # One extra forward/backward pass outside the timer to inspect the
        # gradient (it was zeroed at the end of the last timed replicate).
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)