Ejemplo n.º 1
0
def pbroadcast_value(value, value_axis_names, output_axis_names):
    """Pbroadcasts `value` over the output named axes it does not already have.

    Args:
        value: The value to broadcast; passed straight through to
            `distribute_lib.pbroadcast`.
        value_axis_names: Named axis (or axes) that `value` already carries.
            Normalized with `distribute_lib.canonicalize_named_axis`.
        output_axis_names: Named axes the result should carry. A single axis
            name given as a `str` is treated as a one-element list, and `None`
            as "no axes". This prevents a bare string from being iterated
            character by character.

    Returns:
        The result of `distribute_lib.pbroadcast(value, named_axis=...)`, where
        `named_axis` is the list of axes in `output_axis_names` that are not
        in `value_axis_names` (original order preserved).
    """
    value_axis_names = distribute_lib.canonicalize_named_axis(value_axis_names)
    # Normalize output axes the same way callers would expect from
    # value_axis_names: iterating a plain str would yield single characters.
    if output_axis_names is None:
        output_axis_names = []
    elif isinstance(output_axis_names, str):
        output_axis_names = [output_axis_names]
    pbroadcast_axes = [
        axis_name for axis_name in output_axis_names
        if axis_name not in value_axis_names
    ]
    return distribute_lib.pbroadcast(value, named_axis=pbroadcast_axes)
Ejemplo n.º 2
0
 def manual_sharded_model():
   """Generator model with a hand-inserted `pbroadcast` of its root variable.

   The yields happen in this order:
     1. `root(tfd.LogNormal(0., 1., name='x'))`. The value sent back in is
        bound to `x`.
     2. A `sharded.Sharded` `tfd.Uniform(0., x)` named 'y'.
     3. A `sharded.Sharded` `tfb.Scale(x)(tfd.Normal(0., 1.))` named 'z'.
   Both sharded parts use `self.axis_name`. Before they are built, `x` is
   pbroadcast onto that axis. Presumably the generator is consumed by a
   coroutine-style joint distribution (TODO confirm against the caller).
   """
   # This one has manual pbroadcasts; the goal is to get sharded_model above
   # to do this automatically.
   x = yield root(tfd.LogNormal(0., 1., name='x'))
   # Pass the axis positionally. Other call sites use either `named_axis=`
   # or a positional argument, and the positional form works regardless of
   # which keyword name `pbroadcast` declares for its axis parameter.
   x = distribute_lib.pbroadcast(x, self.axis_name)
   yield sharded.Sharded(
       tfd.Uniform(0., x), shard_axis_name=self.axis_name, name='y')
   yield sharded.Sharded(
       tfb.Scale(x)(tfd.Normal(0., 1.)),
       shard_axis_name=self.axis_name,
       name='z')
Ejemplo n.º 3
0
 def target_log_prob(a, b):
   """Returns the joint log-probability of `a` and `b`.

   `a` is scored under a standard Normal. `b` is scored under a unit-scale
   Normal centred on `a`, after `a` is pbroadcast onto the 'foo' named axis.
   That second term is reduced over 'foo' with `distribute_lib.psum`.
   """
   prior_term = tfd.Normal(0., 1.).log_prob(a)
   a_on_foo = distribute_lib.pbroadcast(a, 'foo')
   per_shard_term = tfd.Normal(a_on_foo, 1.).log_prob(b)
   likelihood_term = distribute_lib.psum(per_shard_term, 'foo')
   return prior_term + likelihood_term
 def adjust_state(x, v, shard_axes=None):
   """Returns `x` advanced by `v` scaled by the step size `dt`.

   `dt` is a free variable taken from the enclosing scope. It is expanded
   with `bu.left_justified_expand_dims_like(dt, v)` and then pbroadcast over
   `shard_axes` (which may be None) before it multiplies `v`.
   """
   expanded_step = bu.left_justified_expand_dims_like(dt, v)
   step = distribute_lib.pbroadcast(expanded_step, shard_axes)
   return x + step * v