def pbroadcast_value(value, value_axis_names, output_axis_names):
  """Pbroadcasts `value` over the output axes it does not already carry.

  Args:
    value: The value to broadcast; passed straight to
      `distribute_lib.pbroadcast`.
    value_axis_names: Named axis (or axes) that `value` already carries.
      Normalized with `distribute_lib.canonicalize_named_axis`.
    output_axis_names: Named axes the result should carry. May be a single
      `str`, an iterable of axis names, or `None` (treated as no axes).

  Returns:
    The result of `distribute_lib.pbroadcast(value, named_axis=axes)`, where
    `axes` lists, in `output_axis_names` order, every output axis name that
    is not in `value_axis_names`.
  """
  value_axis_names = distribute_lib.canonicalize_named_axis(value_axis_names)
  # Normalize `output_axis_names` the same way the value axes are accepted:
  # a bare string is one axis name (iterating it would yield its individual
  # characters as bogus axis names), and `None` means "no output axes"
  # instead of failing with a TypeError when iterated.
  if output_axis_names is None:
    output_axis_names = ()
  elif isinstance(output_axis_names, str):
    output_axis_names = (output_axis_names,)
  pbroadcast_axes = [
      axis_name for axis_name in output_axis_names
      if axis_name not in value_axis_names
  ]
  return distribute_lib.pbroadcast(value, named_axis=pbroadcast_axes)
def manual_sharded_model():
  """Coroutine model whose pbroadcasts are written out by hand.

  This is the manually-broadcast reference version. Per the original note,
  the goal is for a `sharded_model` (defined elsewhere, not visible here) to
  insert these pbroadcasts automatically. `self` is not a parameter; it is a
  free variable captured from an enclosing scope.
  """
  # Root (unsharded) variable; the coroutine receives its value back.
  x = yield root(tfd.LogNormal(0., 1., name='x'))
  # Explicitly broadcast `x` over the shard axis before it parameterizes the
  # sharded components below.
  x = distribute_lib.pbroadcast(x, axis_name=self.axis_name)

  uniform_part = tfd.Uniform(0., x)
  yield sharded.Sharded(
      uniform_part, shard_axis_name=self.axis_name, name='y')

  scale_bijector = tfb.Scale(x)
  scaled_normal = scale_bijector(tfd.Normal(0., 1.))
  yield sharded.Sharded(
      scaled_normal, shard_axis_name=self.axis_name, name='z')
def target_log_prob(a, b):
  """Log-density target over `a` and a `'foo'`-sharded `b`.

  `a` is scored under a standard Normal. `b` is scored under Normal(a, 1),
  with `a` first pbroadcast over the `'foo'` axis. Those log-probs are then
  reduced across `'foo'` with `distribute_lib.psum` and added to the first
  term.
  """
  prior_lp = tfd.Normal(0., 1.).log_prob(a)
  a_on_foo = distribute_lib.pbroadcast(a, 'foo')
  per_shard_lp = tfd.Normal(a_on_foo, 1.).log_prob(b)
  return prior_lp + distribute_lib.psum(per_shard_lp, 'foo')
def adjust_state(x, v, shard_axes=None):
  """Returns `x + dt * v` with `dt` shaped like `v` and pbroadcast.

  `dt` is not a parameter; it is a free variable captured from an enclosing
  scope. It is passed through `bu.left_justified_expand_dims_like(dt, v)`
  (presumably so it broadcasts against `v` — confirm against that helper).
  It is then pbroadcast over `shard_axes` before scaling `v`.
  """
  dt_like_v = bu.left_justified_expand_dims_like(dt, v)
  step = distribute_lib.pbroadcast(dt_like_v, shard_axes)
  return x + step * v