def __init__(
    self,
    topology: TensorflowTreeTopology,
    height_bijector: Bijector,
    sampling_times: tp.Optional[tf.Tensor] = None,
    name="FixedTopologyRootedTreeBijector",
    validate_args=False,
):
    """Build a bijector between node-height parameters and a rooted tree.

    The topology is held fixed. The forward direction consumes a rank-1
    event. The inverse direction works on a ``TensorflowRootedTree``
    structure whose ``node_heights`` and ``sampling_times`` fields are rank 1.
    Its topology field uses ``TensorflowTreeTopology.get_event_ndims()``.

    Args:
        topology: Fixed tree topology; its ``taxon_count`` sizes the default
            sampling times.
        height_bijector: Bijector whose ``dtype`` is used for this bijector
            and for the default sampling times.
        sampling_times: Optional per-taxon sampling times. When ``None``,
            a zero vector of length ``topology.taxon_count`` is used.
        name: Bijector name.
        validate_args: Forwarded to the base bijector constructor.
    """
    self.topology = topology
    self.height_bijector = height_bijector
    # Without explicit sampling times, every taxon is given time zero.
    self.sampling_times = (
        tf.zeros(self.topology.taxon_count, dtype=height_bijector.dtype)
        if sampling_times is None
        else sampling_times
    )
    rooted_tree_event_ndims = TensorflowRootedTree(
        node_heights=1,
        sampling_times=1,
        topology=TensorflowTreeTopology.get_event_ndims(),
    )
    super().__init__(
        name=name,
        validate_args=validate_args,
        forward_min_event_ndims=1,
        inverse_min_event_ndims=rooted_tree_event_ndims,
        dtype=self.height_bijector.dtype,
    )
 def _static_topology_event_shape_tensor(
     taxon_count,
 ) -> TensorflowTreeTopology:
     """Return the topology event shape as a structure of shape tensors.

     NOTE(review): no ``@staticmethod`` is visible in this chunk; the function
     takes no ``self``, so it is presumably used as a static method — confirm.

     Args:
         taxon_count: Number of taxa. Any single-element value accepted by
             ``tf.convert_to_tensor`` works, because it is reshaped to ``[1]``.

     Returns:
         A ``TensorflowTreeTopology`` whose fields are 1-D shape tensors:
         ``[2n - 2]`` for ``parent_indices``, ``[2n - 1, 2]`` for
         ``child_indices`` and ``[2n - 1]`` for ``preorder_indices``.
     """
     # A length-1 vector can be concatenated with the trailing ``[2]`` below.
     taxa = tf.reshape(tf.convert_to_tensor(taxon_count), [1])
     node_count = 2 * taxa - 1
     parent_shape = 2 * taxa - 2
     child_shape = tf.concat([node_count, [2]], axis=0)
     return TensorflowTreeTopology(
         parent_indices=parent_shape,
         child_indices=child_shape,
         preorder_indices=node_count,
     )
Beispiel #3
0
 def _parameter_properties(
     ls, dtype, num_classes=None
 ) -> tp.Dict[str, ParameterProperties]:
     """Describe the event ranks of this distribution's parameters.

     NOTE(review): the first parameter is named ``ls``. In TensorFlow
     Probability this hook is a classmethod taking ``cls``; confirm whether
     the name was truncated. ``dtype`` and ``num_classes`` are not used here.

     Returns:
         A mapping from parameter name to ``ParameterProperties``:

         - ``transition_probs_tree`` has an unrooted-tree event-ndims
           structure. Its ``branch_lengths`` field carries 3 event dims, and
           its topology uses ``TensorflowTreeTopology.get_event_ndims()``.
         - ``frequencies`` has 1 event dim.
     """
     tree_event_ndims = TensorflowUnrootedTree(
         branch_lengths=3, topology=TensorflowTreeTopology.get_event_ndims()
     )
     # TODO: shape_fn
     return {
         "transition_probs_tree": ParameterProperties(event_ndims=tree_event_ndims),
         "frequencies": ParameterProperties(event_ndims=1),
     }
    def __init__(
        self,
        name="TopologyIdentityBijector",
        validate_args=False,
        prob_dtype=DEFAULT_FLOAT_DTYPE_TF,
    ):
        """Construct an identity bijector over integer tree-topology structures.

        Args:
            name: Name scope for the bijector.
            validate_args: Forwarded to the base bijector constructor.
            prob_dtype: Stored as ``self.prob_dtype``; defaults to
                ``DEFAULT_FLOAT_DTYPE_TF``. NOTE(review): presumably the
                floating-point dtype for Jacobian-related outputs, since the
                bijector's own dtype is int32 — confirm with callers.
        """
        # Snapshot of the constructor arguments. This must stay the first
        # statement so that no other local variable ends up in `parameters`.
        parameters = dict(locals())
        # `as name` rebinds `name` to the full name-scope string.
        with tf.name_scope(name) as name:
            # `super(Identity, self)` starts the MRO lookup *after* `Identity`,
            # so `Identity.__init__` itself is skipped. The next base
            # constructor receives a structured `forward_min_event_ndims` (one
            # entry per topology component) and an int32 dtype.
            # NOTE(review): assumes the enclosing class subclasses `Identity`
            # and that the bypass is intentional — confirm.
            super(Identity, self).__init__(
                forward_min_event_ndims=TensorflowTreeTopology.get_event_ndims(
                ),
                dtype=tf.int32,
                is_constant_jacobian=True,
                validate_args=validate_args,
                parameters=parameters,
                name=name,
            )
        self.prob_dtype = prob_dtype

        # Override superclass private fields to eliminate caching, avoiding a memory
        # leak caused by the `y is x` characteristic of this bijector.
        self._from_x = self._from_y = _NoOpCache()
class BaseTreeDistribution(Distribution, tp.Generic[TTree]):
    """Base class for distributions whose samples are tree structures.

    Provides versions of the sampling and log-density wrappers that handle
    nested (tree-structured) values. It also provides helpers that derive
    the event shape of the topology component from ``taxon_count``.
    Subclasses must implement ``_event_shape`` and ``_event_shape_tensor``.
    """

    # Topology components are integer index arrays (see `_topology_dtype`),
    # so none of them is reparameterized.
    _topology_reparameterization_type = TensorflowTreeTopology(
        reparameterization.NOT_REPARAMETERIZED,
        reparameterization.NOT_REPARAMETERIZED,
        reparameterization.NOT_REPARAMETERIZED,
    )
    _topology_dtype = TensorflowTreeTopology(tf.int32, tf.int32, tf.int32)

    def __init__(
        self,
        taxon_count,
        reparameterization_type,
        dtype,
        validate_args=False,
        allow_nan_stats=True,
        tree_name: tp.Optional[str] = None,
        name="TreeDistribution",
        parameters=None,
    ):
        """Record tree metadata, then run the base ``Distribution`` constructor.

        Args:
            taxon_count: Number of taxa. Used to derive the topology event
                shapes.
            reparameterization_type: Forwarded to ``Distribution``.
            dtype: (Possibly nested) dtype structure of samples. It is also
                the packing template for values given to the log-density
                wrappers.
            validate_args: Forwarded to ``Distribution``.
            allow_nan_stats: Forwarded to ``Distribution``.
            tree_name: Optional label, stored as ``self.tree_name``.
            name: Forwarded to ``Distribution``.
            parameters: Constructor-parameter dict forwarded to ``Distribution``.
        """
        # These attributes are set before the base constructor runs.
        self.taxon_count = taxon_count
        self.tree_name = tree_name
        super().__init__(
            name=name,
            dtype=dtype,
            parameters=parameters,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization_type,
        )

    def _call_sample_n(self, sample_shape, seed, name: str, **kwargs) -> TTree:
        """Draw samples and reshape every leaf of the sample structure.

        ``_sample_n`` is asked for ``n`` flat samples, where ``n`` is the
        number of elements of ``sample_shape``. The leading dimension of each
        leaf is then replaced by ``sample_shape``. It is assumed to have size
        ``n``, as ``tf.reshape`` requires.
        """
        with self._name_and_control_scope(name):
            shape_as_tensor = ps.convert_to_shape_tensor(
                ps.cast(sample_shape, tf.int32), name="sample_shape"
            )
            shape_vector, num_draws = self._expand_sample_shape_to_vector(
                shape_as_tensor, "sample_shape"
            )
            # A callable seed is invoked exactly once, here.
            resolved_seed = seed() if callable(seed) else seed
            raw_draws = self._sample_n(num_draws, seed=resolved_seed, **kwargs)

            def _restore_sample_dims(leaf):
                trailing_dims = ps.shape(leaf)[1:]
                return tf.reshape(leaf, ps.concat([shape_vector, trailing_dims], 0))

            shaped_draws = tf.nest.map_structure(_restore_sample_dims, raw_draws)
            return self._set_sample_static_shape(shaped_draws, shape_vector)

    def _call_log_prob(self, value, name, **kwargs):
        """Evaluate ``log_prob`` for ``value`` repacked into ``self.dtype``'s structure.

        Uses ``_log_prob`` when it is defined, otherwise ``log(_prob)``.

        Raises:
            NotImplementedError: If neither ``_log_prob`` nor ``_prob`` exists.
        """
        # Flatten, then repack, so any nesting of `value` is coerced into the
        # structure of `self.dtype` before tensor conversion.
        repacked = tf.nest.pack_sequence_as(self.dtype, tf.nest.flatten(value))
        tensor_value = nest_util.convert_to_nested_tensor(
            repacked, name="value", dtype_hint=self.dtype, allow_packing=True
        )
        with self._name_and_control_scope(name, tensor_value, kwargs):
            if hasattr(self, "_log_prob"):
                return self._log_prob(tensor_value, **kwargs)
            if not hasattr(self, "_prob"):
                raise NotImplementedError(
                    "log_prob is not implemented: {}".format(type(self).__name__)
                )
            return tf.math.log(self._prob(tensor_value, **kwargs))

    def _call_unnormalized_log_prob(self, value, name, **kwargs):
        """Evaluate the unnormalized log density of ``value``.

        The first available implementation is used, in this order:
        ``_unnormalized_log_prob``, ``log(_unnormalized_prob)``,
        ``_log_prob``, ``log(_prob)``.

        Raises:
            NotImplementedError: If none of those methods exists.
        """
        repacked = tf.nest.pack_sequence_as(self.dtype, tf.nest.flatten(value))
        tensor_value = nest_util.convert_to_nested_tensor(
            repacked, name="value", dtype_hint=self.dtype, allow_packing=True
        )
        # (method name, whether its result is a probability that needs a log)
        fallback_chain = (
            ("_unnormalized_log_prob", False),
            ("_unnormalized_prob", True),
            ("_log_prob", False),
            ("_prob", True),
        )
        with self._name_and_control_scope(name, tensor_value, kwargs):
            for method_name, needs_log in fallback_chain:
                if hasattr(self, method_name):
                    result = getattr(self, method_name)(tensor_value, **kwargs)
                    return tf.math.log(result) if needs_log else result
            raise NotImplementedError(
                "unnormalized_log_prob is not implemented: {}".format(
                    type(self).__name__
                )
            )

    @abstractmethod
    def _event_shape(self) -> TTree:
        """Static event shape, structured like the tree type ``TTree``."""

    @abstractmethod
    def _event_shape_tensor(self) -> TTree:
        """Tensor-valued event shape, structured like the tree type ``TTree``."""

    @staticmethod
    def _static_topology_event_shape(taxon_count) -> TensorflowTreeTopology:
        """Static ``TensorShape`` for each topology component.

        For ``n = taxon_count`` the shapes are ``[2n - 2]`` for
        ``parent_indices``, ``[2n - 1, 2]`` for ``child_indices`` and
        ``[2n - 1]`` for ``preorder_indices``.
        """
        non_root_count = 2 * taxon_count - 2
        node_count = 2 * taxon_count - 1
        return TensorflowTreeTopology(
            parent_indices=tf.TensorShape([non_root_count]),
            child_indices=tf.TensorShape([node_count, 2]),
            preorder_indices=tf.TensorShape([node_count]),
        )

    def _topology_event_shape(self) -> TensorflowTreeTopology:
        """Static topology event shape for this instance's ``taxon_count``."""
        return type(self)._static_topology_event_shape(self.taxon_count)

    @staticmethod
    def _static_topology_event_shape_tensor(
        taxon_count,
    ) -> TensorflowTreeTopology:
        """Tensor-valued counterpart of ``_static_topology_event_shape``.

        Each field is a 1-D shape tensor with the same values as the static
        version. ``taxon_count`` may be any single-element value accepted by
        ``tf.convert_to_tensor``.
        """
        # A length-1 vector can be concatenated with the trailing ``[2]``.
        taxa = tf.reshape(tf.convert_to_tensor(taxon_count), [1])
        node_count = 2 * taxa - 1
        parent_shape = 2 * taxa - 2
        child_shape = tf.concat([node_count, [2]], axis=0)
        return TensorflowTreeTopology(
            parent_indices=parent_shape,
            child_indices=child_shape,
            preorder_indices=node_count,
        )

    def _topology_event_shape_tensor(self) -> TensorflowTreeTopology:
        """Tensor-valued topology event shape for this instance's ``taxon_count``."""
        return type(self)._static_topology_event_shape_tensor(self.taxon_count)
 def _static_topology_event_shape(taxon_count) -> TensorflowTreeTopology:
     """Return the static ``TensorShape`` of each topology component.

     NOTE(review): no ``@staticmethod`` is visible in this chunk; the function
     takes no ``self`` — confirm how it is attached.

     For ``n = taxon_count`` the shapes are ``[2n - 2]`` for
     ``parent_indices``, ``[2n - 1, 2]`` for ``child_indices`` and
     ``[2n - 1]`` for ``preorder_indices``.
     """
     non_root_count = 2 * taxon_count - 2
     node_count = 2 * taxon_count - 1
     return TensorflowTreeTopology(
         parent_indices=tf.TensorShape([non_root_count]),
         child_indices=tf.TensorShape([node_count, 2]),
         preorder_indices=tf.TensorShape([node_count]),
     )
def test_TensorflowTreeTopology_has_rank_to_has_batch_dimensions_static(
    rank, expected
):
    """Check ``rank_to_has_batch_dimensions`` against an expected flag.

    NOTE(review): this copy has no ``parametrize`` decorator, and a later
    definition with the same name in this file replaces it at import time.
    It is likely a scraping duplicate — confirm whether it should be removed.
    """
    assert TensorflowTreeTopology.rank_to_has_batch_dimensions(rank) == expected
        # NOTE(review): this chunk omits the `def` line of this test and the
        # start of the call this argument belongs to; only the tail is shown.
        parent_indices=flat_tree_test_data.parent_indices)
    # Convert the NumPy topology into its TensorFlow representation.
    tf_topology = numpy_topology_to_tensor(numpy_topology)
    # The ranks are expected as NumPy arrays (statically known), not tensors:
    # 1 for the parent/preorder index vectors and 2 for child_indices.
    rank = tf_topology.get_prefer_static_rank()
    assert isinstance(rank.parent_indices, np.ndarray)
    assert isinstance(rank.preorder_indices, np.ndarray)
    assert isinstance(rank.child_indices, np.ndarray)
    assert rank.parent_indices == 1
    assert rank.preorder_indices == 1
    assert rank.child_indices == 2


@pytest.mark.parametrize(
    "rank,expected",
    [
        # Static shape, no batch
        (
            TensorflowTreeTopology(
                parent_indices=1, preorder_indices=1, child_indices=2
            ),
            False,
        ),
        # Static shape, with batch
        (
            TensorflowTreeTopology(
                parent_indices=1, preorder_indices=1, child_indices=4
            ),
            True,
        ),
    ],
)  # TODO: Tests for dynamic shape?
def test_TensorflowTreeTopology_has_rank_to_has_batch_dimensions_static(
    rank, expected
):
    """Batch dimensions are detected from static per-component ranks.

    A ``child_indices`` rank of 2 is treated as unbatched and a rank of 4 as
    batched.
    """
    assert TensorflowTreeTopology.rank_to_has_batch_dimensions(rank) == expected