def recenter(rv_constructor, *rv_args, **rv_kwargs):
    """Interceptor applying a learnable partial non-centering to Normal RVs.

    Every ``Normal`` whose name does not start with ``'y'`` (presumably the
    observed variables, by naming convention -- confirm) is rewritten as

        z ~ Normal(loc * a, scale ** b)              (named ``name + '_param'``)
        x = scale ** (1 - b) * (z - loc * a) + loc   (via ``tfb.AffineScalar``)

    so ``x`` keeps mean ``loc`` and standard deviation ``scale``. With
    ``a = b = 1`` the bijector is the identity (centered form); with
    ``a = b = 0`` the standard variable is ``Normal(0, 1)`` (non-centered
    form). ``a`` and ``b`` come from ``get_or_init(name, shape)`` --
    presumably per-variable trainable parameters cached by name; verify
    against ``get_or_init``.

    If a ``value`` is supplied for ``x``, it is mapped back to ``z`` through
    the bijector's inverse. The bijector is recorded in ``bijectors[name]``.
    Any other call is forwarded unchanged through ``interceptable``.

    Args:
      rv_constructor: The random-variable constructor being intercepted.
      *rv_args: Positional arguments for the constructor (passed through).
      **rv_kwargs: Keyword arguments; for rewritten Normals ``name``,
        ``loc`` and ``scale`` must be given as keywords.

    Returns:
      The bijector's forward output for rewritten Normals, otherwise the
      result of the intercepted constructor.

    Raises:
      KeyError: If a ``Normal`` call lacks ``name`` (or, for rewritten
        variables, ``loc``/``scale``) among its keyword arguments.
    """
    if (rv_constructor.__name__ == 'Normal'
            and not rv_kwargs['name'].startswith('y')):
        # NB: assume everything is kwargs for now.
        x_loc = rv_kwargs['loc']
        x_scale = rv_kwargs['scale']
        name = rv_kwargs['name']

        # Build (and discard) an ordinary RV purely to learn its shape.
        shape = rv_constructor(*rv_args, **rv_kwargs).shape

        # Interpolation weights: `a` scales the loc, `b` is the scale exponent.
        a, b = get_or_init(name, shape)

        kwargs_std = {
            'loc': tf.multiply(x_loc, a),
            'scale': tf.pow(x_scale, b),
            'name': name + '_param',
        }

        # Affine map restoring the original loc/scale:
        #   x = scale * z + (x_loc - scale * x_loc * a).
        scale = tf.pow(x_scale, 1. - b)
        bijector = tfb.AffineScalar(
            scale=scale,
            shift=x_loc + tf.multiply(scale, -kwargs_std['loc']))
        if 'value' in rv_kwargs:
            # Pin the reparameterized variable to the pre-image of the value.
            kwargs_std['value'] = bijector.inverse(rv_kwargs['value'])

        rv_std = interceptable(rv_constructor)(*rv_args, **kwargs_std)
        bijectors[name] = bijector
        return bijector.forward(rv_std)

    return interceptable(rv_constructor)(*rv_args, **rv_kwargs)
 def set_values(f, *args, **kwargs):
     """Interceptor that supplies a ``value`` for each random variable.

     A value keyed by the RV's ``name`` in ``model_kwargs`` takes priority.
     Otherwise, if ``consumable_args`` is non-empty, its first element is
     popped and used (the list is mutated). If neither applies, the call is
     forwarded without a ``value``. ``model_kwargs`` and ``consumable_args``
     are free names, presumably from an enclosing scope not shown here.
     """
     rv_name = kwargs.get('name')
     if rv_name not in model_kwargs:
         if consumable_args:
             # Positional values are consumed in the order RVs are created.
             kwargs['value'] = consumable_args.pop(0)
     else:
         kwargs['value'] = model_kwargs[rv_name]
     wrapped = interceptable(f)
     return wrapped(*args, **kwargs)
def ncp(rv_constructor, *rv_args, **rv_kwargs):
    """Interceptor that rewrites Normal random variables in non-centered form.

    Every ``Normal`` whose name does not start with ``'y'`` (presumably the
    observed variables, by naming convention -- confirm) is replaced by

        z ~ Normal(0, 1)          (named ``name + '_std'``)
        x = scale * z + loc       (via ``tfb.AffineScalar``)

    and the bijector's forward output is returned. If a ``value`` is
    supplied for ``x``, it is mapped back to ``z`` through the bijector's
    inverse. Any other call is forwarded unchanged through ``interceptable``.

    Args:
      rv_constructor: The random-variable constructor being intercepted.
      *rv_args: Positional arguments for the constructor (passed through).
      **rv_kwargs: Keyword arguments; for rewritten Normals ``name``,
        ``loc`` and ``scale`` must be given as keywords.

    Returns:
      The bijector's forward output for rewritten Normals, otherwise the
      result of the intercepted constructor.

    Raises:
      KeyError: If a ``Normal`` call lacks ``name`` (or, for rewritten
        variables, ``loc``/``scale``) among its keyword arguments.
    """
    if (rv_constructor.__name__ == 'Normal'
            and not rv_kwargs['name'].startswith('y')):
        loc = rv_kwargs['loc']
        scale = rv_kwargs['scale']
        name = rv_kwargs['name']

        # Build (and discard) an ordinary RV to learn its shape and dtype.
        template = rv_constructor(*rv_args, **rv_kwargs)
        shape = template.shape
        # tf.zeros/tf.ones default to float32; match the model's dtype so a
        # float64 loc/scale does not clash with the standard variable below.
        dtype = template.distribution.dtype

        kwargs_std = {}
        kwargs_std['loc'] = tf.zeros(shape, dtype=dtype)
        kwargs_std['scale'] = tf.ones(shape, dtype=dtype)
        kwargs_std['name'] = name + '_std'

        bijector = tfb.AffineScalar(scale=scale, shift=loc)
        if 'value' in rv_kwargs:
            # Pin the standard variable to the pre-image of the given value.
            kwargs_std['value'] = bijector.inverse(rv_kwargs['value'])

        rv_std = interceptable(rv_constructor)(*rv_args, **kwargs_std)

        return bijector.forward(rv_std)

    return interceptable(rv_constructor)(*rv_args, **rv_kwargs)
 def trace(rv_constructor, *rv_args, **rv_kwargs):
     """Interceptor that records each random variable's value by name.

     The RV is built through ``interceptable``; its ``.value`` is then stored
     in ``trace_result`` (a free name, presumably a dict from an enclosing
     scope) under ``rv_kwargs['name']``. The RV itself is returned as-is.
     """
     make_rv = interceptable(rv_constructor)
     random_variable = make_rv(*rv_args, **rv_kwargs)
     # Look the name up before touching `.value`, as the original did.
     rv_name = rv_kwargs['name']
     trace_result[rv_name] = random_variable.value
     return random_variable
 def set_values(f, *args, **kwargs):
     """Interceptor pinning named random variables to values in ``model_kwargs``.

     When the RV's ``name`` is a key of ``model_kwargs`` (a free name,
     presumably from an enclosing scope), that entry becomes its ``value``;
     otherwise the call is forwarded unchanged.
     """
     rv_name = kwargs.get("name")
     if rv_name in model_kwargs:
         # Override (or supply) the RV's value with the user-provided one.
         kwargs["value"] = model_kwargs[rv_name]
     wrapped = interceptable(f)
     return wrapped(*args, **kwargs)