Beispiel #1
0
    def interceptor(f, *args, **kwargs):
        """Set random variable values to their aligned values.

        Builds the intercepted random variable once to obtain its own value,
        then rebuilds it with ``value`` set to ``var_value`` when
        ``var_condition`` is true, or to that own value otherwise.
        ``var_condition`` and ``var_value`` come from the enclosing scope
        (not visible here).

        Args:
            f: The random-variable constructor being intercepted.
            *args, **kwargs: Arguments forwarded to ``f``; any ``value``
                entry in ``kwargs`` is replaced for the final call.

        Returns:
            The random variable produced by the second
            ``ed.interceptable(f)`` call.
        """
        # Construct the variable once so its own value can serve as the
        # false branch of the condition below.
        _value = ed.interceptable(f)(*args, **kwargs).value

        conditional_value = tf.cond(var_condition, lambda: var_value,
                                    lambda: _value)

        # Assign through kwargs rather than passing ``value=`` next to
        # ``**kwargs``. Passing both raises TypeError ("multiple values for
        # keyword argument 'value'") whenever the caller already supplied a
        # value. Broadcasting keeps the result's shape equal to _value.shape.
        kwargs["value"] = tf.broadcast_to(conditional_value, _value.shape)
        return ed.interceptable(f)(*args, **kwargs)
Beispiel #2
0
    def interceptor(f, *args, **kwargs):
        """Set random variable values to their aligned values.

        The value injected into the rebuilt random variable is ``var_value``
        when ``var_condition`` holds, otherwise the variable's own value.
        Both names are resolved from the enclosing scope (not visible here).

        Args:
            f: The random-variable constructor being intercepted.
            *args, **kwargs: Arguments forwarded to ``f``; ``kwargs["value"]``
                is overwritten before the final call.

        Returns:
            The random variable built by the final ``ed.interceptable(f)``
            call.
        """
        # Need to create the variable to obtain its value, which is used in
        # the false branch of tf.cond.
        _value = ed.interceptable(f)(*args, **kwargs).value

        conditional_value = tf.cond(var_condition, lambda: var_value,
                                    lambda: _value)

        # Broadcast to fix the shape of the tensor (it must always be
        # _value.shape). Store it in kwargs instead of passing ``value=``
        # alongside ``**kwargs``, which raises TypeError for a duplicate
        # keyword if the caller already passed ``value``.
        kwargs["value"] = tf.broadcast_to(conditional_value, _value.shape)
        return ed.interceptable(f)(*args, **kwargs)
Beispiel #3
0
    def interceptor(f, *args, **kwargs):
        """Pin a random variable to the value stored under its name.

        When ``kwargs["name"]`` is a key of ``model_kwargs`` (resolved from
        the enclosing scope), that entry is passed on as ``value``; otherwise
        the call is forwarded untouched.
        """
        rv_name = kwargs.get("name")
        if rv_name not in model_kwargs:
            return ed.interceptable(f)(*args, **kwargs)
        kwargs["value"] = model_kwargs[rv_name]
        return ed.interceptable(f)(*args, **kwargs)
Beispiel #4
0
 def interceptor(f, *args, **kwargs):
     """Pin a random variable's value when its name appears in ``model_kwargs``.

     Names that are not keys of ``model_kwargs`` are reported on stdout and
     the call proceeds with the original arguments.
     """
     name = kwargs.get("name")
     if name not in model_kwargs:
         print(f"set_values not interested in {name}.")
     else:
         kwargs["value"] = model_kwargs[name]
     return ed.interceptable(f)(*args, **kwargs)
Beispiel #5
0
    def interceptor(f, *args, **kwargs):
        """Set a random variable's value from ``model_kwargs`` when its name matches.

        Looks up ``kwargs["name"]`` in ``model_kwargs``. If it is absent, the
        call is forwarded unchanged. If it is present and conditional
        interception is enabled (``ALLOW_CONDITIONS`` truthy and
        ``CURRENT_ENABLE_INTERCEPTOR`` not None), the injected value is chosen
        with ``tf.cond``. Otherwise ``model_kwargs[name]`` is injected
        unconditionally. ``model_kwargs``, ``ALLOW_CONDITIONS``,
        ``CURRENT_ENABLE_INTERCEPTOR`` and ``data_model`` are defined outside
        this function (not visible here).

        Args:
            f: The random-variable constructor being intercepted.
            *args, **kwargs: Arguments forwarded to ``f``; ``kwargs["value"]``
                may be overwritten.

        Returns:
            The result of calling ``ed.interceptable(f)`` with the (possibly
            updated) arguments.
        """
        name = kwargs.get("name")

        # If the name is a key of model_kwargs, pass model_kwargs[name] (possibly conditionally) as ``value``.
        if name in model_kwargs:
            interception_value = model_kwargs[name]

            if ALLOW_CONDITIONS and CURRENT_ENABLE_INTERCEPTOR is not None:
                # Bind the current (enable_globals, enable_locals) pair to locals, so that later
                # reassignment of CURRENT_ENABLE_INTERCEPTOR (e.g. back to None) does not affect what is
                # used here. Per the original note these are condition variables created in the
                # inference method object — confirm against caller.
                enable_globals, enable_locals = CURRENT_ENABLE_INTERCEPTOR
                # Replace a missing flag with constant False so the tf.logical_and calls below always
                # receive a tensor.
                if enable_globals is None:
                    enable_globals = tf.constant(False)
                if enable_locals is None:
                    enable_locals = tf.constant(False)

                # Build the random variable once (with the caller's kwargs) to obtain its own value,
                # used as the false branch of tf.cond. The variable is constructed again at the
                # final return below.
                _value = ed.interceptable(f)(*args, **kwargs).value

                # True if this is a local hidden variable, False if global hidden. Presumably
                # data_model.is_active() reports whether the data_model context manager is
                # currently entered — verify.
                is_local_hidden = data_model.is_active()

                # Use interception_value when (globals enabled and the variable is global) or
                # (locals enabled and the variable is local); otherwise keep the variable's own value.
                conditional_value = tf.cond(
                    tf.logical_or(
                        tf.logical_and(enable_globals, tf.constant(not is_local_hidden)),
                        tf.logical_and(enable_locals, tf.constant(is_local_hidden))
                        ),
                    lambda: interception_value,
                    lambda: _value
                )

                # Broadcast so the injected value always has shape _value.shape.
                kwargs['value'] = tf.broadcast_to(conditional_value, _value.shape)
            else:
                # Conditional interception disabled: always inject the stored value.
                kwargs['value'] = interception_value

        return ed.interceptable(f)(*args, **kwargs)
Beispiel #6
0
 def set_values(f, *args, **kwargs):
     """Inject ``model_kwargs[name]`` as ``value`` for random variables it names."""
     var_name = kwargs.get("name")
     overrides = {"value": model_kwargs[var_name]} if var_name in model_kwargs else {}
     kwargs.update(overrides)
     return ed.interceptable(f)(*args, **kwargs)
Beispiel #7
0
 def set_xy(f, *args, **kwargs):
   """Fix ``x`` to 1.0 and ``y`` to 0.42; leave every other variable alone."""
   for target, pinned in (("x", 1.), ("y", 0.42)):
     if kwargs.get("name") == target:
       kwargs["value"] = pinned
   return ed.interceptable(f)(*args, **kwargs)
Beispiel #8
0
 def double(f, *args, **kwargs):
   """Return the intercepted random variable multiplied by two."""
   rv = ed.interceptable(f)(*args, **kwargs)
   return 2. * rv
Beispiel #9
0
 def trivial_interceptor(fn, *args, **kwargs):
   """Pass-through interceptor: forward the call without modification."""
   wrapped = ed.interceptable(fn)
   return wrapped(*args, **kwargs)
 def set_values(f, *args, **kwargs):
   """Override ``value`` with ``model_kwargs[name]`` when the name is a known key."""
   key = kwargs.get("name")
   if key in model_kwargs:
     kwargs.update(value=model_kwargs[key])
   return ed.interceptable(f)(*args, **kwargs)
 def set_xy(f, *args, **kwargs):
   """Pin variable ``x`` to 1.0 and ``y`` to 0.42 before forwarding the call."""
   rv_name = kwargs.get("name")
   # Collect every matching pin; the last match wins, as with sequential ifs.
   matches = [val for label, val in (("x", 1.), ("y", 0.42)) if rv_name == label]
   if matches:
     kwargs["value"] = matches[-1]
   return ed.interceptable(f)(*args, **kwargs)
 def double(f, *args, **kwargs):
   """Scale whatever the intercepted call produces by a factor of 2."""
   make_rv = ed.interceptable(f)
   scale = 2.
   return scale * make_rv(*args, **kwargs)
 def trivial_interceptor(fn, *args, **kwargs):
   """Identity interceptor; returns exactly what the intercepted call returns."""
   result = ed.interceptable(fn)(*args, **kwargs)
   return result