def _get_flat_unconstraining_bijector(jd_model):
  """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then have
  support on R^m. In that case the corresponding part contributes m, not n,
  dimensions to N.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  event_space_bij = jd_model.experimental_default_event_space_bijector()
  flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

  # The event-space bijector may change a part's event shape (see the
  # docstring: support on R^m rather than R^n), so the unconstrained shapes
  # must be computed through *both* the event-space and restructuring
  # bijectors. Using `flat_bijector` alone would yield the constrained shapes
  # and make the reshapes/split sizes below disagree with the actual values.
  unconstrained_shapes = chain.Chain(
      [event_space_bij, flat_bijector]).inverse_event_shape_tensor(
          jd_model.event_shape_tensor())

  # This reshaping is required as split can produce a tensor of shape [1]
  # when the distribution event shape is [].
  reshapers = [
      reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
      for x in unconstrained_shapes
  ]
  # Number of scalar entries each unconstrained part occupies in the flat
  # vector.
  size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]

  # `Invert(...).forward` is the chain's inverse: it unconstrains each part,
  # turns the structure into a list, flattens each part, and joins the parts
  # via `Split`'s inverse.
  return invert.Invert(chain.Chain([
      event_space_bij,
      flat_bijector,
      joint_map.JointMap(bijectors=reshapers),
      split.Split(num_or_size_splits=size_splits),
  ]))
def _get_flat_unconstraining_bijector(jd_model):
  """Create bijectors from a joint distribution that flatten and unconstrain.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a list of unconstrained parts, part j supported on R^{m_j}. The parts are
  each flattened to a single event dimension but are *not* concatenated into
  one vector. (This is "loose" in the sense of base measures: some
  distribution may be supported on an m-dimensional subset of R^n, and the
  default transform for that distribution may then have support on R^m, so
  m_j may be smaller than n_j.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    A pair `(bij, step_size_bij)` of `tfb.Bijector`s:
      bij: its `.forward` unconstrains a model point with the model's default
        event-space bijector, converts the resulting structure into a list,
        and reshapes each part's event to a single (flat) dimension.
      step_size_bij: the inverse of the `restructure.pack_sequence_as`
        bijector only (no unconstraining or reshaping), i.e. its `.forward`
        presumably turns a model-structured value into a list. It may be used
        to initialize a step size.
  """
  # TODO(b/180396233): This bijector is in general point-dependent.
  event_space_bij = jd_model.experimental_default_event_space_bijector()
  flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

  # Calling a bijector on a bijector composes them, so the shapes are pushed
  # through both the event-space and restructuring bijectors. This matters
  # because the event-space bijector may change a part's event shape.
  unconstrained_shapes = event_space_bij(
      flat_bijector).inverse_event_shape_tensor(jd_model.event_shape_tensor())

  # Map flat [-1] events back to each part's unconstrained event shape. The
  # reshaping is required because a piece obtained by splitting a flat vector
  # (presumably done by callers -- confirm) has shape [1] when the part's
  # event shape is [].
  unsplit = joint_map.JointMap(
      tf.nest.map_structure(
          lambda x: reshape.Reshape(event_shape_out=x, event_shape_in=[-1]),
          unconstrained_shapes))

  bij = invert.Invert(chain.Chain([event_space_bij, flat_bijector, unsplit]))
  step_size_bij = invert.Invert(flat_bijector)
  return bij, step_size_bij