def __init__(
    self,
    topology: TensorflowTreeTopology,
    height_bijector: Bijector,
    sampling_times: tp.Optional[tf.Tensor] = None,
    name="FixedTopologyRootedTreeBijector",
    validate_args=False,
):
    """Configure a bijector that produces a tree with a fixed topology.

    Args:
        topology: Tree topology, stored unchanged on ``self.topology``.
        height_bijector: Stored on ``self.height_bijector``. Its ``dtype``
            becomes this bijector's dtype and the dtype of default sampling
            times.
        sampling_times: Optional tensor of sampling times. When ``None``,
            a zero vector of length ``topology.taxon_count`` is used.
        name: Forwarded to the base constructor.
        validate_args: Forwarded to the base constructor.

    The forward event has rank 1. The inverse event is described by a
    ``TensorflowRootedTree`` of per-field event ranks. NOTE(review): the
    forward pass presumably maps a vector to node heights through
    ``height_bijector``; confirm against ``_forward``.
    """
    self.topology = topology
    self.height_bijector = height_bijector
    self.sampling_times = (
        tf.zeros(self.topology.taxon_count, dtype=height_bijector.dtype)
        if sampling_times is None
        else sampling_times
    )
    # Event ranks for each field of the rooted tree on the inverse side.
    tree_event_ndims = TensorflowRootedTree(
        node_heights=1,
        sampling_times=1,
        topology=TensorflowTreeTopology.get_event_ndims(),
    )
    super().__init__(
        name=name,
        validate_args=validate_args,
        forward_min_event_ndims=1,
        inverse_min_event_ndims=tree_event_ndims,
        dtype=self.height_bijector.dtype,
    )
def _static_topology_event_shape_tensor(
    taxon_count,
) -> TensorflowTreeTopology:
    """Return the event shape of each topology field as a rank-1 shape tensor.

    For ``n = taxon_count`` the shapes are:

    * ``parent_indices``: ``[2n - 2]``
    * ``child_indices``: ``[2n - 1, 2]``
    * ``preorder_indices``: ``[2n - 1]`` (the node count)
    """
    # Lift the taxon count to a length-1 vector so it can be concatenated.
    taxa = tf.reshape(tf.convert_to_tensor(taxon_count), [1])
    num_nodes = 2 * taxa - 1
    parent_shape = 2 * taxa - 2
    child_shape = tf.concat([num_nodes, [2]], axis=0)
    return TensorflowTreeTopology(
        parent_indices=parent_shape,
        child_indices=child_shape,
        preorder_indices=num_nodes,
    )
def _parameter_properties(
    ls, dtype, num_classes=None
) -> tp.Dict[str, ParameterProperties]:
    """Declare the event rank of each distribution parameter.

    * ``transition_probs_tree``: a ``TensorflowUnrootedTree`` whose
      ``branch_lengths`` field has event rank 3 and whose ``topology`` uses
      the standard topology event ranks. NOTE(review): presumably one
      square transition matrix per branch; confirm.
    * ``frequencies``: event rank 1.

    ``dtype`` and ``num_classes`` are accepted but not used.
    """
    transition_tree_ndims = TensorflowUnrootedTree(
        branch_lengths=3, topology=TensorflowTreeTopology.get_event_ndims()
    )
    # TODO: shape_fn
    return {
        "transition_probs_tree": ParameterProperties(
            event_ndims=transition_tree_ndims
        ),
        "frequencies": ParameterProperties(event_ndims=1),
    }
def __init__(
    self,
    name="TopologyIdentityBijector",
    validate_args=False,
    prob_dtype=DEFAULT_FLOAT_DTYPE_TF,
):
    """Construct an identity bijector whose events are tree-topology structures.

    Args:
        name: Bijector name. A ``tf.name_scope`` is opened with it, and the
            resolved scope name is passed to the base constructor.
        validate_args: Forwarded to the base constructor.
        prob_dtype: Stored on ``self.prob_dtype``. NOTE(review): presumably
            the floating dtype for probability / log-det-Jacobian outputs,
            since the bijector's own dtype is ``tf.int32``; confirm.
    """
    # Snapshot the constructor arguments before any other local is bound, so
    # they can be handed to the base class as ``parameters``.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
        # ``super(Identity, self)`` starts the MRO lookup *after* ``Identity``,
        # so ``Identity.__init__`` is skipped. The next constructor in the MRO
        # is called directly with topology-structured event ndims, an int32
        # dtype and a constant Jacobian.
        # NOTE(review): assumes this class subclasses ``Identity``; confirm.
        super(Identity, self).__init__(
            forward_min_event_ndims=TensorflowTreeTopology.get_event_ndims(
            ),
            dtype=tf.int32,
            is_constant_jacobian=True,
            validate_args=validate_args,
            parameters=parameters,
            name=name,
        )
    self.prob_dtype = prob_dtype
    # Override superclass private fields to eliminate caching, avoiding a memory
    # leak caused by the `y is x` characteristic of this bijector.
    self._from_x = self._from_y = _NoOpCache()
class BaseTreeDistribution(Distribution, tp.Generic[TTree]):
    """Base class for distributions whose values are (nested) tree structures.

    Subclasses must implement ``_event_shape`` and ``_event_shape_tensor``.
    This base class provides nest-aware wrappers for sampling and
    (unnormalized) log-probability. It also provides helpers describing the
    event shape of the topology component for a given taxon count.
    """

    # Every topology field is non-reparameterized and int32-valued.
    _topology_reparameterization_type = TensorflowTreeTopology(
        reparameterization.NOT_REPARAMETERIZED,
        reparameterization.NOT_REPARAMETERIZED,
        reparameterization.NOT_REPARAMETERIZED,
    )
    _topology_dtype = TensorflowTreeTopology(tf.int32, tf.int32, tf.int32)

    def __init__(
        self,
        taxon_count,
        reparameterization_type,
        dtype,
        validate_args=False,
        allow_nan_stats=True,
        tree_name: tp.Optional[str] = None,
        name="TreeDistribution",
        parameters=None,
    ):
        """Store tree metadata and forward the rest to ``Distribution``.

        Args:
            taxon_count: Number of taxa, used by the topology-shape helpers.
            reparameterization_type: Forwarded to the base constructor.
            dtype: Forwarded to the base constructor. It is also the
                structure template that incoming values are repacked into.
            validate_args: Forwarded to the base constructor.
            allow_nan_stats: Forwarded to the base constructor.
            tree_name: Optional name stored on ``self.tree_name``.
            name: Forwarded to the base constructor.
            parameters: Forwarded to the base constructor.
        """
        self.taxon_count = taxon_count
        self.tree_name = tree_name
        super().__init__(
            dtype=dtype,
            reparameterization_type=reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name,
            parameters=parameters,
        )

    def _call_sample_n(self, sample_shape, seed, name: str, **kwargs) -> TTree:
        """Wrapper around _sample_n.

        Draws ``n = prod(sample_shape)`` flat samples. The leading dimension
        of every leaf (assumed to be ``n``) is then reshaped into
        ``sample_shape``. A callable ``seed`` is invoked to obtain the seed.
        """
        with self._name_and_control_scope(name):
            int_shape = ps.cast(sample_shape, tf.int32)
            sample_shape = ps.convert_to_shape_tensor(int_shape, name="sample_shape")
            sample_shape, n = self._expand_sample_shape_to_vector(
                sample_shape, "sample_shape"
            )
            resolved_seed = seed() if callable(seed) else seed
            flat_samples = self._sample_n(n, seed=resolved_seed, **kwargs)

            def _unflatten_leading_dim(leaf):
                # Keep everything after the flat sample axis (batch + event).
                trailing_dims = ps.shape(leaf)[1:]
                return tf.reshape(leaf, ps.concat([sample_shape, trailing_dims], 0))

            reshaped = tf.nest.map_structure(_unflatten_leading_dim, flat_samples)
            return self._set_sample_static_shape(reshaped, sample_shape)

    def _call_log_prob(self, value, name, **kwargs):
        """Wrapper around _log_prob.

        ``value`` is repacked into the structure of ``self.dtype`` and
        converted to tensors. The first available implementation among
        ``_log_prob`` and ``log(_prob)`` is then used.
        """
        structured = tf.nest.pack_sequence_as(self.dtype, tf.nest.flatten(value))
        value = nest_util.convert_to_nested_tensor(
            structured, name="value", dtype_hint=self.dtype, allow_packing=True
        )
        with self._name_and_control_scope(name, value, kwargs):
            for method_name, needs_log in (("_log_prob", False), ("_prob", True)):
                if hasattr(self, method_name):
                    result = getattr(self, method_name)(value, **kwargs)
                    return tf.math.log(result) if needs_log else result
            raise NotImplementedError(
                "log_prob is not implemented: {}".format(type(self).__name__)
            )

    def _call_unnormalized_log_prob(self, value, name, **kwargs):
        """Wrapper around _unnormalized_log_prob.

        Uses the same value repacking as ``_call_log_prob``. Falls back, in
        order, through ``_unnormalized_log_prob``, ``log(_unnormalized_prob)``,
        ``_log_prob`` and ``log(_prob)``.
        """
        structured = tf.nest.pack_sequence_as(self.dtype, tf.nest.flatten(value))
        value = nest_util.convert_to_nested_tensor(
            structured, name="value", dtype_hint=self.dtype, allow_packing=True
        )
        with self._name_and_control_scope(name, value, kwargs):
            fallbacks = (
                ("_unnormalized_log_prob", False),
                ("_unnormalized_prob", True),
                ("_log_prob", False),
                ("_prob", True),
            )
            for method_name, needs_log in fallbacks:
                if hasattr(self, method_name):
                    result = getattr(self, method_name)(value, **kwargs)
                    return tf.math.log(result) if needs_log else result
            raise NotImplementedError(
                "unnormalized_log_prob is not implemented: {}".format(
                    type(self).__name__
                )
            )

    @abstractmethod
    def _event_shape(self) -> TTree:
        """Static event shape of a sampled tree (implemented by subclasses)."""

    @abstractmethod
    def _event_shape_tensor(self) -> TTree:
        """Dynamic event shape of a sampled tree (implemented by subclasses)."""

    @staticmethod
    def _static_topology_event_shape(taxon_count) -> TensorflowTreeTopology:
        """Static ``TensorShape`` of each topology field for ``taxon_count`` taxa.

        With ``n = taxon_count``: parent ``[2n - 2]``, child ``[2n - 1, 2]``,
        preorder ``[2n - 1]``.
        """
        node_count = 2 * taxon_count - 1
        return TensorflowTreeTopology(
            parent_indices=tf.TensorShape([2 * taxon_count - 2]),
            child_indices=tf.TensorShape([node_count, 2]),
            preorder_indices=tf.TensorShape([node_count]),
        )

    def _topology_event_shape(self) -> TensorflowTreeTopology:
        """Static topology event shape for this distribution's taxon count."""
        return type(self)._static_topology_event_shape(self.taxon_count)

    @staticmethod
    def _static_topology_event_shape_tensor(
        taxon_count,
    ) -> TensorflowTreeTopology:
        """Topology event shapes as rank-1 tensors (dynamic counterpart)."""
        taxa = tf.reshape(tf.convert_to_tensor(taxon_count), [1])
        num_nodes = 2 * taxa - 1
        parent_shape = 2 * taxa - 2
        child_shape = tf.concat([num_nodes, [2]], axis=0)
        return TensorflowTreeTopology(
            parent_indices=parent_shape,
            child_indices=child_shape,
            preorder_indices=num_nodes,
        )

    def _topology_event_shape_tensor(self) -> TensorflowTreeTopology:
        """Dynamic topology event shape for this distribution's taxon count."""
        return type(self)._static_topology_event_shape_tensor(self.taxon_count)
def _static_topology_event_shape(taxon_count) -> TensorflowTreeTopology:
    """Return the static ``TensorShape`` of each topology field.

    With ``n = taxon_count``:

    * ``parent_indices``: ``[2n - 2]``
    * ``child_indices``: ``[2n - 1, 2]``
    * ``preorder_indices``: ``[2n - 1]``
    """
    node_count = 2 * taxon_count - 1
    parent_shape = tf.TensorShape([2 * taxon_count - 2])
    child_shape = tf.TensorShape([node_count, 2])
    preorder_shape = tf.TensorShape([node_count])
    return TensorflowTreeTopology(
        parent_indices=parent_shape,
        child_indices=child_shape,
        preorder_indices=preorder_shape,
    )
def test_TensorflowTreeTopology_has_rank_to_has_batch_dimensions_static(
    rank, expected
):
    """``rank_to_has_batch_dimensions`` must match the expected batch flag."""
    assert TensorflowTreeTopology.rank_to_has_batch_dimensions(rank) == expected
parent_indices=flat_tree_test_data.parent_indices) tf_topology = numpy_topology_to_tensor(numpy_topology) rank = tf_topology.get_prefer_static_rank() assert isinstance(rank.parent_indices, np.ndarray) assert isinstance(rank.preorder_indices, np.ndarray) assert isinstance(rank.child_indices, np.ndarray) assert rank.parent_indices == 1 assert rank.preorder_indices == 1 assert rank.child_indices == 2 @pytest.mark.parametrize( ["rank", "expected"], [ ( TensorflowTreeTopology( parent_indices=1, preorder_indices=1, child_indices=2), False, ), # Static shape, no batch ( TensorflowTreeTopology( parent_indices=1, preorder_indices=1, child_indices=4), True, ), # Static shape, with batch ], ) # TODO: Tests for dynamic shape? def test_TensorflowTreeTopology_has_rank_to_has_batch_dimensions_static( rank, expected): res = TensorflowTreeTopology.rank_to_has_batch_dimensions(rank) assert res == expected