Пример #1
0
def policy_loss_coeff_advantage_indicator(
        networks: CRRNetworks,
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        transition: types.Transition,
        key: networks_lib.PRNGKey,
        num_action_samples: int = 4) -> jnp.ndarray:
    """Return the indicator-style policy loss coefficient.

    Implements the indicator advantage weighting; see equation (3) in the
    CRR paper.

    Args:
      networks: forwarded unchanged to `_compute_advantage`.
      policy_params: forwarded unchanged to `_compute_advantage`.
      critic_params: forwarded unchanged to `_compute_advantage`.
      transition: forwarded unchanged to `_compute_advantage`.
      key: forwarded unchanged to `_compute_advantage`.
      num_action_samples: forwarded unchanged to `_compute_advantage`.

    Returns:
      The Heaviside step of the advantage: 1 where the advantage is positive
      and 0 where it is negative or exactly zero.
    """
    return jnp.heaviside(
        _compute_advantage(networks, policy_params, critic_params, transition,
                           key, num_action_samples),
        0.)
Пример #2
0
def step(x):
    """Apply the unit (Heaviside) step function element-wise.

    Inputs that are exactly zero map to 0. The original author's note states
    that the derivative of this step is defined and equal to 0 at every
    point.

    Parameters
    ----------
    x : array-like
        Array to apply step to.


    Returns
    -------
    step_x : array-like
        step(x)
    """
    value_at_zero = 0
    step_x = np.heaviside(x, value_at_zero)
    return step_x
Пример #3
0
    def relax_plus_rebar_objective(encoder_params,
                                   decoder_params,
                                   surrogate_params,
                                   log_temperature,
                                   batch,
                                   prng_key,
                                   num_samples=1):
        """
        Computes the RELAX + REBAR objective function of a two-layer discrete VAE.

        Hard samples for each layer are obtained by thresholding relaxed
        samples with a Heaviside step. The ELBO is evaluated at the hard
        samples and at the sigmoid-relaxed unconditional / conditional
        samples, and the output of ``surrogate`` is added to the relaxed
        control-variate terms.

        :param encoder_params: pair of encoder parameters (first layer, second layer)
        :param decoder_params: pair of decoder parameters (first layer, second layer)
        :param surrogate_params: surrogate parameters (list)
        :param log_temperature: log of inverse temperature; its exponential
            multiplies the relaxed samples inside the sigmoid
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples drawn for the first layer
        :return: negative mean of the objective over axis 0, divided by
            ``batch.shape[0]``
        """
        temperature = np.exp(log_temperature)

        def concrete(x):
            return jax.nn.sigmoid(x * temperature)

        # ``encoder`` and ``decoder`` come from the enclosing scope and are
        # unpacked as (first layer, second layer) pairs.
        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        _, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        # logpmf mapped over the leading axis of the samples only; the
        # distribution parameters are shared across that axis.
        sample_logpmf = vmap(bernoulli.logpmf, in_axes=(0, None))

        # Sampling
        sampling_key, control_variate_key = random.split(prng_key)
        first_layer_sampling_key, second_layer_sampling_key = random.split(
            sampling_key)
        first_layer_control_variate_key, second_layer_control_variate_key = random.split(
            control_variate_key)

        # Outer Layer
        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_relaxed_samples = binary_relaxed.sample(
            first_layer_params, first_layer_sampling_key, num_samples)
        first_layer_posterior_samples = np.heaviside(
            first_layer_relaxed_samples, 0)
        first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            first_layer_params, first_layer_posterior_samples,
            first_layer_control_variate_key)

        # Inner Layer
        second_layer_params = encoder2(encoder_params2,
                                       first_layer_posterior_samples)
        # A single sample is drawn and the new leading axis is dropped again.
        second_layer_relaxed_samples = binary_relaxed.sample(
            second_layer_params, second_layer_sampling_key,
            num_samples=1)[0, ...]
        second_layer_posterior_samples = np.heaviside(
            second_layer_relaxed_samples, 0)
        second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            second_layer_params, second_layer_posterior_samples,
            second_layer_control_variate_key)

        # Inner layer posterior
        log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                                second_layer_params),
                               axis=(1, 2))
        # Outer layer posterior
        log_posterior += np.sum(sample_logpmf(first_layer_posterior_samples,
                                              first_layer_params),
                                axis=(1, 2))

        def elbo_samples(first_layer_samples,
                         second_layer_samples,
                         stop_grads=False):
            # Inner layer prior
            log_prior = np.sum(bernoulli.logpmf(
                second_layer_samples,
                decoder2(decoder_params2, first_layer_samples)),
                               axis=(1, 2))  # SxBxD
            # Outer layer prior: constant 0.5 parameter. Its width follows the
            # first-layer encoder output instead of a hard-coded 200.
            outer_prior_params = 0.5 * np.ones(
                (batch.shape[0], first_layer_params.shape[-1]))
            log_prior += np.sum(sample_logpmf(first_layer_samples,
                                              outer_prior_params),
                                axis=(1, 2))  # SxBxD

            # With stop_grads the posterior parameters are frozen so that no
            # gradient flows through them in this evaluation.
            if stop_grads:
                inner_params = lax.stop_gradient(second_layer_params)
                outer_params = lax.stop_gradient(first_layer_params)
            else:
                inner_params = second_layer_params
                outer_params = first_layer_params

            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                    inner_params),
                                   axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(sample_logpmf(first_layer_samples,
                                                  outer_params),
                                    axis=(1, 2))

            # Likelihood
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params1,
                                                           first_layer_samples,
                                                           batch)

            return log_likelihood - log_posterior + log_prior

        elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                       second_layer_posterior_samples)
        unconditional_evaluation = elbo_samples(
            concrete(first_layer_relaxed_samples),
            concrete(second_layer_relaxed_samples))
        conditional_evaluation = elbo_samples(
            concrete(first_layer_conditional_relaxed_samples),
            concrete(second_layer_conditional_relaxed_samples))
        frozen_conditional_evaluation = elbo_samples(
            concrete(
                lax.stop_gradient(first_layer_conditional_relaxed_samples)),
            concrete(
                lax.stop_gradient(second_layer_conditional_relaxed_samples)),
            stop_grads=True)

        # The surrogate sees both layers' relaxed samples side by side.
        relaxed_surrogate_inputs = np.concatenate(
            [first_layer_relaxed_samples, second_layer_relaxed_samples],
            axis=-1)

        conditional_relaxed_surrogate_inputs = np.concatenate(
            [
                first_layer_conditional_relaxed_samples,
                second_layer_conditional_relaxed_samples
            ],
            axis=-1)

        # Score-function term: the weight on log_posterior is built from
        # gradient-stopped quantities (frozen ELBO and frozen surrogate input).
        obj = (lax.stop_gradient(elbo_evaluation) -
               frozen_conditional_evaluation -
               np.sum(surrogate(
                   surrogate_params,
                   lax.stop_gradient(conditional_relaxed_surrogate_inputs)),
                      axis=1).squeeze()) * log_posterior

        # Differentiable correction: unconditional minus conditional terms.
        obj += unconditional_evaluation + np.sum(surrogate(
            surrogate_params, relaxed_surrogate_inputs),
                                                 axis=1).squeeze()
        obj -= conditional_evaluation + np.sum(surrogate(
            surrogate_params, conditional_relaxed_surrogate_inputs),
                                               axis=1).squeeze()

        return -np.mean(obj, axis=0) / batch.shape[0]
Пример #4
0
    def tunable_rebar_objective(encoder_params,
                                decoder_params,
                                batch,
                                prng_key,
                                cv_coeff=1.0,
                                log_temperature=0.5,
                                num_samples=1):
        """
        Computes the REBAR objective function of a two-layer discrete VAE.

        Hard samples for each layer are obtained by thresholding relaxed
        samples with a Heaviside step. The ELBO is evaluated at the hard
        samples and at the sigmoid-relaxed unconditional / conditional
        samples; the relaxed terms are scaled by ``cv_coeff``.

        :param encoder_params: pair of encoder parameters (first layer, second layer)
        :param decoder_params: pair of decoder parameters (first layer, second layer)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param cv_coeff: control variate coefficient
        :param log_temperature: log_temperature parameter for rebar control
            variate; its exponential multiplies the relaxed samples inside
            the sigmoid
        :param num_samples: number of samples drawn for the first layer
        :return: negative mean of the objective over axis 0, divided by
            ``batch.shape[0]``
        """
        temperature = np.exp(log_temperature)

        def concrete(x):
            return jax.nn.sigmoid(x * temperature)

        # ``encoder`` and ``decoder`` come from the enclosing scope and are
        # unpacked as (first layer, second layer) pairs.
        encoder1, encoder2 = encoder
        encoder_params1, encoder_params2 = encoder_params
        _, decoder2 = decoder
        decoder_params1, decoder_params2 = decoder_params

        # logpmf mapped over the leading axis of the samples only; the
        # distribution parameters are shared across that axis.
        sample_logpmf = vmap(bernoulli.logpmf, in_axes=(0, None))

        # Sampling
        sampling_key, control_variate_key = random.split(prng_key)
        first_layer_sampling_key, second_layer_sampling_key = random.split(
            sampling_key)
        first_layer_control_variate_key, second_layer_control_variate_key = random.split(
            control_variate_key)

        # Outer Layer
        first_layer_params = encoder1(encoder_params1, batch)  # BxD
        first_layer_relaxed_samples = binary_relaxed.sample(
            first_layer_params, first_layer_sampling_key, num_samples)
        first_layer_posterior_samples = np.heaviside(
            first_layer_relaxed_samples, 0)
        first_layer_concrete_samples = concrete(first_layer_relaxed_samples)
        first_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            first_layer_params, first_layer_posterior_samples,
            first_layer_control_variate_key)
        first_layer_cv_samples = concrete(
            first_layer_conditional_relaxed_samples)
        # Same relaxation, but with the gradient through the conditional
        # samples blocked.
        first_layer_frozen_cv_samples = concrete(
            lax.stop_gradient(first_layer_conditional_relaxed_samples))

        # Inner Layer
        second_layer_params = encoder2(encoder_params2,
                                       first_layer_posterior_samples)
        # A single sample is drawn and the new leading axis is dropped again.
        second_layer_relaxed_samples = binary_relaxed.sample(
            second_layer_params, second_layer_sampling_key,
            num_samples=1)[0, ...]
        second_layer_posterior_samples = np.heaviside(
            second_layer_relaxed_samples, 0)
        second_layer_concrete_samples = concrete(second_layer_relaxed_samples)
        second_layer_conditional_relaxed_samples = binary_relaxed.conditional_sample(
            second_layer_params, second_layer_posterior_samples,
            second_layer_control_variate_key)
        second_layer_cv_samples = concrete(
            second_layer_conditional_relaxed_samples)
        second_layer_frozen_cv_samples = concrete(
            lax.stop_gradient(second_layer_conditional_relaxed_samples))

        # Inner layer posterior
        log_posterior = np.sum(bernoulli.logpmf(second_layer_posterior_samples,
                                                second_layer_params),
                               axis=(1, 2))
        # Outer layer posterior
        log_posterior += np.sum(sample_logpmf(first_layer_posterior_samples,
                                              first_layer_params),
                                axis=(1, 2))

        def elbo_samples(first_layer_samples,
                         second_layer_samples,
                         stop_grads=False):
            # Inner layer prior
            log_prior = np.sum(bernoulli.logpmf(
                second_layer_samples,
                decoder2(decoder_params2, first_layer_samples)),
                               axis=(1, 2))  # SxBxD
            # Outer layer prior: constant 0.5 parameter. Its width follows the
            # first-layer encoder output instead of a hard-coded 200.
            outer_prior_params = 0.5 * np.ones(
                (batch.shape[0], first_layer_params.shape[-1]))
            log_prior += np.sum(sample_logpmf(first_layer_samples,
                                              outer_prior_params),
                                axis=(1, 2))  # SxBxD

            # With stop_grads the posterior parameters are frozen so that no
            # gradient flows through them in this evaluation.
            if stop_grads:
                inner_params = lax.stop_gradient(second_layer_params)
                outer_params = lax.stop_gradient(first_layer_params)
            else:
                inner_params = second_layer_params
                outer_params = first_layer_params

            # Inner layer posterior
            log_posterior = np.sum(bernoulli.logpmf(second_layer_samples,
                                                    inner_params),
                                   axis=(1, 2))
            # Outer layer posterior
            log_posterior += np.sum(sample_logpmf(first_layer_samples,
                                                  outer_params),
                                    axis=(1, 2))

            # Likelihood
            log_likelihood = vmap(bernoulli_log_likelihood,
                                  in_axes=(None, 0, None))(decoder_params1,
                                                           first_layer_samples,
                                                           batch)

            return log_likelihood - log_posterior + log_prior

        elbo_evaluation = elbo_samples(first_layer_posterior_samples,
                                       second_layer_posterior_samples)
        concrete_evaluation = elbo_samples(first_layer_concrete_samples,
                                           second_layer_concrete_samples)
        cv_evaluation = elbo_samples(first_layer_cv_samples,
                                     second_layer_cv_samples)
        frozen_cv_evaluation = elbo_samples(first_layer_frozen_cv_samples,
                                            second_layer_frozen_cv_samples,
                                            True)

        # Score-function term weighted by gradient-stopped quantities.
        obj = (lax.stop_gradient(elbo_evaluation) -
               cv_coeff * frozen_cv_evaluation) * log_posterior

        # Differentiable correction: unconditional minus conditional terms.
        obj += cv_coeff * (concrete_evaluation - cv_evaluation)

        return -np.mean(obj, axis=0) / batch.shape[0]
Пример #5
0
def heaviside(x1, x2):
  """Heaviside step of ``x1`` with ``x2`` as the value at zero, as a JaxArray.

  Either argument may be a ``JaxArray``; in that case its ``.value`` is
  passed to ``jnp.heaviside``. The result is always wrapped in ``JaxArray``.
  """
  lhs = x1.value if isinstance(x1, JaxArray) else x1
  rhs = x2.value if isinstance(x2, JaxArray) else x2
  return JaxArray(jnp.heaviside(lhs, rhs))
Пример #6
0
    def relax_plus_rebar_objective(encoder_params,
                                   decoder_params,
                                   surrogate_params,
                                   log_temperature,
                                   batch,
                                   prng_key,
                                   num_samples=1):
        """
        Computes the RELAX + REBAR objective function of a discrete VAE.

        Hard samples are obtained by thresholding relaxed samples with a
        Heaviside step. The ELBO is evaluated at the hard samples and at the
        sigmoid-relaxed unconditional / conditional samples, and the output
        of ``surrogate`` is added to the relaxed control-variate terms.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param surrogate_params: surrogate parameters (list)
        :param log_temperature: log of inverse temperature; its exponential
            multiplies the relaxed samples inside the sigmoid
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        :return: negative mean of the objective over axis 0, divided by
            ``batch.shape[0]``
        """
        inverse_temperature = np.exp(log_temperature)

        def soften(z):
            # Sigmoid relaxation of the raw relaxed samples.
            return jax.nn.sigmoid(z * inverse_temperature)

        def surrogate_total(inputs):
            # Surrogate output summed over axis 1, singleton axes removed.
            return np.sum(surrogate(surrogate_params, inputs),
                          axis=1).squeeze()

        q_params = encoder(encoder_params, batch)  # BxD

        # logpmf mapped over the leading axis of the samples only.
        logpmf_per_sample = vmap(bernoulli.logpmf, in_axes=(0, None))

        # Sampling
        unconditional_key, conditional_key = random.split(prng_key)
        relaxed = binary_relaxed.sample(q_params, unconditional_key,
                                        num_samples)
        hard = np.heaviside(relaxed, 0)
        conditional = binary_relaxed.conditional_sample(
            q_params, hard, conditional_key)

        log_posterior = np.sum(logpmf_per_sample(hard, q_params),
                               axis=(1, 2))

        def elbo_of(latents, stop_grads=False):
            prior_params = 0.5 * np.ones(
                (batch.shape[0], q_params.shape[-1]))
            prior_term = np.sum(logpmf_per_sample(latents, prior_params),
                                axis=(1, 2))  # SxBxD

            # Optionally block gradients through the posterior parameters.
            if stop_grads:
                active_params = lax.stop_gradient(q_params)
            else:
                active_params = q_params
            posterior_term = np.sum(logpmf_per_sample(latents, active_params),
                                    axis=(1, 2))

            likelihood_term = vmap(bernoulli_log_likelihood,
                                   in_axes=(None, 0, None))(decoder_params,
                                                            latents, batch)
            return likelihood_term - posterior_term + prior_term

        hard_elbo = elbo_of(hard)
        relaxed_elbo = elbo_of(soften(relaxed))
        conditional_elbo = elbo_of(soften(conditional))
        frozen_conditional_elbo = elbo_of(
            soften(lax.stop_gradient(conditional)), True)

        # Score-function term weighted by gradient-stopped quantities.
        obj = (lax.stop_gradient(hard_elbo) -
               frozen_conditional_elbo -
               surrogate_total(lax.stop_gradient(conditional))) * log_posterior

        # Differentiable correction: unconditional minus conditional terms.
        obj += relaxed_elbo + surrogate_total(relaxed)
        obj -= conditional_elbo + surrogate_total(conditional)

        return -np.mean(obj, axis=0) / batch.shape[0]
Пример #7
0
    def relax_objective(encoder_params,
                        decoder_params,
                        surrogate_params,
                        batch,
                        prng_key,
                        num_samples=1):
        """
        Computes the RELAX objective function of a discrete VAE.

        Hard samples are obtained by thresholding relaxed samples with a
        Heaviside step. The ELBO is evaluated at the hard samples only; the
        control variate is built purely from the output of ``surrogate`` at
        the unconditional and conditional relaxed samples.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param surrogate_params: surrogate parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param num_samples: number of samples
        :return: negative mean of the objective over axis 0, divided by
            ``batch.shape[0]``
        """
        q_params = encoder(encoder_params, batch)  # BxD

        # logpmf mapped over the leading axis of the samples only.
        logpmf_per_sample = vmap(bernoulli.logpmf, in_axes=(0, None))

        def surrogate_of(inputs):
            return surrogate(surrogate_params, inputs)

        # Sampling
        unconditional_key, conditional_key = random.split(prng_key)
        relaxed = binary_relaxed.sample(q_params, unconditional_key,
                                        num_samples)
        hard = np.heaviside(relaxed, 0)
        conditional = binary_relaxed.conditional_sample(
            q_params, hard, conditional_key)

        log_posterior = np.sum(logpmf_per_sample(hard, q_params),
                               axis=(1, 2))

        def elbo_of(latents):
            prior_params = 0.5 * np.ones(
                (batch.shape[0], q_params.shape[-1]))
            prior_term = np.sum(logpmf_per_sample(latents, prior_params),
                                axis=(1, 2))  # SxBxD
            posterior_term = np.sum(logpmf_per_sample(latents, q_params),
                                    axis=(1, 2))
            likelihood_term = vmap(bernoulli_log_likelihood,
                                   in_axes=(None, 0, None))(decoder_params,
                                                            latents, batch)
            return likelihood_term - posterior_term + prior_term

        hard_elbo = elbo_of(hard)

        # Score-function term weighted by gradient-stopped quantities.
        frozen_surrogate = np.sum(
            surrogate_of(lax.stop_gradient(conditional)), axis=1).squeeze()
        obj = (lax.stop_gradient(hard_elbo) - frozen_surrogate) * log_posterior

        # Differentiable correction: surrogate at unconditional minus
        # conditional samples, differenced before the sum.
        obj += np.sum(surrogate_of(relaxed) - surrogate_of(conditional),
                      axis=1).squeeze()

        return -np.mean(obj, axis=0) / batch.shape[0]
Пример #8
0
    def tunable_rebar_objective(encoder_params,
                                decoder_params,
                                batch,
                                prng_key,
                                cv_coeff=1.0,
                                log_temperature=0.5,
                                num_samples=1):
        """
        Computes the REBAR objective function of a discrete VAE.

        Hard samples are obtained by thresholding relaxed samples with a
        Heaviside step. The ELBO is evaluated at the hard samples and at the
        sigmoid-relaxed unconditional / conditional samples; the relaxed
        terms are scaled by ``cv_coeff``.

        :param encoder_params: encoder parameters (list)
        :param decoder_params: decoder parameters (list)
        :param batch: batch of data (jax.numpy array)
        :param prng_key: PRNG key
        :param cv_coeff: control variate coefficient
        :param log_temperature: log_temperature parameter for rebar control
            variate; its exponential multiplies the relaxed samples inside
            the sigmoid
        :param num_samples: number of samples
        :return: negative mean of the objective over axis 0, divided by
            ``batch.shape[0]``
        """
        inverse_temperature = np.exp(log_temperature)
        q_params = encoder(encoder_params, batch)  # BxD

        def soften(z):
            # Sigmoid relaxation of the raw relaxed samples.
            return jax.nn.sigmoid(z * inverse_temperature)

        # logpmf mapped over the leading axis of the samples only.
        logpmf_per_sample = vmap(bernoulli.logpmf, in_axes=(0, None))

        # Sampling
        unconditional_key, conditional_key = random.split(prng_key)
        relaxed = binary_relaxed.sample(q_params, unconditional_key,
                                        num_samples)
        hard = np.heaviside(relaxed, 0)
        conditional_relaxed = binary_relaxed.conditional_sample(
            q_params, hard, conditional_key)

        log_posterior = np.sum(logpmf_per_sample(hard, q_params),
                               axis=(1, 2))

        def elbo_of(latents, stop_grads=False):
            prior_params = 0.5 * np.ones(
                (batch.shape[0], q_params.shape[-1]))
            prior_term = np.sum(logpmf_per_sample(latents, prior_params),
                                axis=(1, 2))  # SxBxD

            # Optionally block gradients through the posterior parameters.
            if stop_grads:
                active_params = lax.stop_gradient(q_params)
            else:
                active_params = q_params
            posterior_term = np.sum(logpmf_per_sample(latents, active_params),
                                    axis=(1, 2))

            likelihood_term = vmap(bernoulli_log_likelihood,
                                   in_axes=(None, 0, None))(decoder_params,
                                                            latents, batch)
            return likelihood_term - posterior_term + prior_term

        hard_elbo = elbo_of(hard)
        relaxed_elbo = elbo_of(soften(relaxed))
        conditional_elbo = elbo_of(soften(conditional_relaxed))
        frozen_conditional_elbo = elbo_of(
            soften(lax.stop_gradient(conditional_relaxed)), True)

        # Score-function term weighted by gradient-stopped quantities.
        obj = (lax.stop_gradient(hard_elbo) -
               cv_coeff * frozen_conditional_elbo) * log_posterior

        # Differentiable correction: unconditional minus conditional terms.
        obj += cv_coeff * (relaxed_elbo - conditional_elbo)

        return -np.mean(obj, axis=0) / batch.shape[0]
Пример #9
0
 def f(x):
     """Return the first element of the Heaviside step of ``x`` (0.5 at zero)."""
     stepped = np.heaviside(x, 0.5)
     return stepped[0]
Пример #10
0
def To(xi, const):
    """Blend ``To1`` and ``To2`` with Heaviside masks on ``xi['beta']``.

    ``To1`` is weighted 1 where ``xi['beta'] < pi`` and 0 otherwise; ``To2``
    is weighted 1 where ``xi['beta'] >= pi`` and 0 otherwise. Both functions
    are always evaluated and combined arithmetically.
    """
    beta_offset = xi['beta'] - np.pi
    below_pi = np.heaviside(-beta_offset, 0.)
    at_or_above_pi = np.heaviside(beta_offset, 1.)
    return below_pi * To1(xi, const) + at_or_above_pi * To2(xi, const)