# Example #1
# 0
def _get_mdn_coef(output):
    """Split a mixture-density-network output into its coefficient groups.

    The last axis of ``output`` is cut into three equal chunks, in order:
    mixture log-weights, component means and component log standard
    deviations. The log-weights are renormalised with a log-sum-exp so that
    ``exp(logmix)`` sums to one along the last axis.

    Args:
        output: Tensor whose last dimension is divisible by three.

    Returns:
        Tuple ``(logmix, mean, logstd)`` of equally shaped tensors.
    """
    raw_logmix, mean, logstd = tf.split(output, 3, axis=-1)
    log_normalizer = tf.reduce_logsumexp(raw_logmix, axis=-1, keepdims=True)
    return raw_logmix - log_normalizer, mean, logstd
# Example #2
# 0
 def softmax_loss(self, antecedent_scores, antecedent_labels):
     """Negative marginal log-likelihood of the gold antecedents.

     For every row the loss is ``logsumexp(all scores) - logsumexp(gold
     scores)``, i.e. ``-log`` of the total softmax mass on gold antecedents.

     Args:
         antecedent_scores: Scores of shape [k, max_ant + 1].
         antecedent_labels: Gold indicator of the same shape; presumably
             boolean or 0/1 (TODO confirm). Entries that cast to 0 become
             ``-inf`` after the log and drop out of the gold sum.

     Returns:
         Per-row loss of shape [k].
     """
     # log(1) == 0 keeps gold scores unchanged; log(0) == -inf masks the rest.
     label_log_mask = tf.log(tf.to_float(antecedent_labels))  # [k, max_ant + 1]
     gold_scores = antecedent_scores + label_log_mask
     gold_log_mass = tf.reduce_logsumexp(gold_scores, [1])  # [k]
     total_log_mass = tf.reduce_logsumexp(antecedent_scores, [1])  # [k]
     return total_log_mass - gold_log_mass  # [k]
# Example #3
# 0
    def __init__(self,
                 env_spec,
                 expert_trajs=None,
                 discrim_arch=feedforward_energy,
                 discrim_arch_args=None,
                 l2_reg=0,
                 discount=1.0,
                 init_itrs=None,
                 score_dtau=False,
                 state_only=False,
                 name='trajprior'):
        """Build the trajectory-level GAN-GCL discriminator graph.

        The discriminator is D(tau) = p(tau) / (p(tau) + q(tau)), with
        log p(tau) the (optionally discounted) sum of negative energies and
        log q(tau) the summed per-step log-probs fed via `traj_logprobs`.

        Args:
            env_spec: Environment spec; its observation/action spaces provide
                ``flat_dim`` for the observation (dO) and action (dU) sizes.
            expert_trajs: Demonstrations forwarded to ``self.set_demos``.
            discrim_arch: Callable building the energy tensor from the
                [batch, T, features] input.
            discrim_arch_args: Extra keyword arguments for ``discrim_arch``.
                ``None`` (the default) means no extra arguments.
            l2_reg: Coefficient of the L2 penalty on the discriminator's
                trainable variables; disabled when <= 0.
            discount: If < 1, energies are summed with
                ``discounted_reduce_sum``; otherwise summed undiscounted.
            init_itrs: Accepted for interface compatibility; unused here.
            score_dtau: Stored on ``self.score_dtau``.
            state_only: If True the discriminator sees observations only.
            name: Variable scope for all ops created here.
        """
        super(GAN_GCL, self).__init__()
        if discrim_arch_args is None:
            # Fresh dict per call: a literal `{}` default would be one object
            # shared by every instance of the class.
            discrim_arch_args = {}
        self.dO = env_spec.observation_space.flat_dim
        self.dU = env_spec.action_space.flat_dim
        self.score_dtau = score_dtau
        self.set_demos(expert_trajs)

        # build energy model
        with tf.variable_scope(name) as vs:
            # Should be batch_size x T x dO/dU
            self.obs_t = tf.placeholder(tf.float32, [None, None, self.dO],
                                        name='obs')
            self.act_t = tf.placeholder(tf.float32, [None, None, self.dU],
                                        name='act')
            # [batch, T]; summed over time into log q(tau) below.
            self.traj_logprobs = tf.placeholder(tf.float32, [None, None],
                                                name='traj_probs')
            self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            if state_only:
                obs_act = self.obs_t
            else:
                obs_act = tf.concat([self.obs_t, self.act_t], axis=2)

            with tf.variable_scope('discrim') as vs2:
                self.energy = discrim_arch(obs_act, **discrim_arch_args)
                discrim_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs2.name)

            self.energy_timestep = self.energy
            # Don't train separate log Z because we can't fully separate it from the energy function
            if discount >= 1.0:
                log_p_tau = tf.reduce_sum(-self.energy, axis=1)
            else:
                log_p_tau = discounted_reduce_sum(-self.energy,
                                                  discount=discount,
                                                  axis=1)
            log_q_tau = tf.reduce_sum(self.traj_logprobs,
                                      axis=1,
                                      keep_dims=True)

            # numerical stability trick: log_pq = log(p + q) via logsumexp,
            # so d_tau = p / (p + q) never exponentiates large values.
            # Assumes log_p_tau also has shape [batch, 1] (energy presumably
            # [batch, T, 1]) — TODO confirm against discrim_arch.
            log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
            self.d_tau = tf.exp(log_p_tau - log_pq)
            # Binary cross-entropy of D(tau) against `labels`.
            cent_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                        (1 - self.labels) *
                                        (log_q_tau - log_pq))

            if l2_reg > 0:
                reg_loss = l2_reg * tf.reduce_sum(
                    [tf.reduce_sum(tf.square(var)) for var in discrim_vars])
            else:
                reg_loss = 0

            self.loss = cent_loss + reg_loss
            self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(
                self.loss)
            self._make_param_ops(vs)
# Example #4
# 0
    def _build(self, x, presence=None):
        """Score points `x` under a per-point mixture of capsule votes plus a dummy.

        Each input point is modelled by a mixture over the n_caps capsule
        votes for that point and one extra "dummy" component whose
        log-likelihood and logit are both fixed at -2*log(10). Mixing logits
        come from `math_ops.safe_log(self._vote_presence_prob)`.

        Args:
            x: Points of shape [B, n_input_points, n_input_dims]; the first
                two dimensions must be statically known (they size tf.zeros).
            presence: Optional per-point mask; when given it is cast to float
                and multiplies the per-point log-likelihood.

        Returns:
            `self.OutputTuple` with the batch-mean log-likelihood, hard and
            soft winning votes/presences, posterior mixing probabilities
            (dummy dropped, transposed to [B, n_input_points, n_caps]) and
            the mixing logits / normalised mixing log-probs.
        """

        # x is [B, n_input_points, n_input_dims]
        batch_size, n_input_points = x.shape[:2].as_list()

        # votes and scale have shape [B, n_caps, n_input_points, n_input_dims|1]
        # since scale is a per-caps scalar and we have one vote per capsule
        vote_component_pdf = self._get_pdf(self._votes,
                                           tf.expand_dims(self._scales, -1))

        # expand along caps dimensions -> [B, 1, n_input_points, n_input_dims]
        expanded_x = tf.expand_dims(x, 1)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_caps, n_input_points]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        # Constant log-likelihood for the dummy component.
        dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
        dummy_vote_log_prob -= 2. * tf.log(10.)

        # [B, n_caps + 1, n_input_points]
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

        # [B, n_caps, n_input_points]
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
        dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

        # [B, n_caps + 1, n_input_points]
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
        # Mixing logits normalised over the component axis (axis 1).
        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)
        # [B, n_input_points]
        # NOTE(review): this sums the *unnormalised* mixing_logits with the
        # vote log-likelihoods, so the result is not a normalised mixture
        # log-likelihood; `mixing_log_prob` above is only returned, not used
        # here. Confirm this is intended.
        mixture_log_prob_per_point = tf.reduce_logsumexp(
            mixing_logits + vote_log_prob, 1)

        if presence is not None:
            presence = tf.to_float(presence)
            # Masked-out points contribute 0 to the log-likelihood.
            mixture_log_prob_per_point *= presence

        # [B,]
        mixture_log_prob_per_example\
          = tf.reduce_sum(mixture_log_prob_per_point, 1)

        # []
        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_caps + 1, n_input_points]
        # Softmax/argmax over axis 1 are unaffected by whether the logits
        # are normalised, since normalising subtracts a per-point constant.
        posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

        # [B, n_input_points]; the dummy component is excluded from the argmax.
        winning_vote_idx = tf.argmax(posterior_mixing_logits_per_point[:, :-1],
                                     1)

        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
        batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

        point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
        point_idx = snt.TileByDim([0], [batch_size])(point_idx)

        # [B, n_input_points, 3] indices (batch, winning capsule, point).
        idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        # A vote counts as present when its logit exceeds the dummy logit.
        vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:,
                                                                        -1:])

        # Integer division of the winning index by self._n_votes
        # (presumably votes per capsule — confirm for this class, where the
        # vote axis has one entry per capsule).
        is_from_capsule = winning_vote_idx // self._n_votes

        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, 1)

        # Learned vote for the dummy component, tiled over the batch so the
        # soft (posterior-weighted) winner below has n_caps + 1 candidates.
        dummy_vote = tf.get_variable('dummy_vote',
                                     shape=self._votes[:1, :1].shape)
        dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
        dummy_pres = tf.zeros([batch_size, 1, n_input_points])

        votes = tf.concat((self._votes, dummy_vote), 1)
        pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

        # Posterior-weighted averages over the n_caps + 1 components.
        soft_winner = tf.reduce_sum(
            tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
        soft_winner_pres = tf.reduce_sum(posterior_mixing_probs * pres, 1)

        # Drop the dummy and move to [B, n_input_points, n_caps].
        posterior_mixing_probs = tf.transpose(posterior_mixing_probs[:, :-1],
                                              (0, 2, 1))

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.to_float(vote_presence),
            winner=winning_vote,
            winner_pres=winning_pres,
            soft_winner=soft_winner,
            soft_winner_pres=soft_winner_pres,
            posterior_mixing_probs=posterior_mixing_probs,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
        )
def model_fn(features, labels, mode, params, config):
    """Build the VAE model function for use with a `tf.estimator.Estimator`.

    Args:
      features: The input images for the estimator.
      labels: Unused.
      mode: A `tf.estimator.ModeKeys` value.
      params: Hyperparameter dict. Reads "analytic_kl", "mixture_components",
        "activation", "latent_size", "base_depth", "n_samples",
        "learning_rate" and "max_steps".
      config: The RunConfig, unused here.

    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec minimising the negative ELBO.

    Raises:
      NotImplementedError: if `analytic_kl` is requested with more than one
        mixture component (no closed-form KL in that case).
    """
    del labels, config

    if params["analytic_kl"] and params["mixture_components"] != 1:
        raise NotImplementedError(
            "Using `analytic_kl` is only supported when `mixture_components = 1` "
            "since there's no closed form otherwise.")

    def _as_float(tensor):
        # Image summaries are written from float32 tensors.
        return tf.cast(tensor, dtype=tf.float32)

    activation = params["activation"]
    latent_size = params["latent_size"]
    base_depth = params["base_depth"]
    encoder = make_encoder(activation, latent_size, base_depth)
    decoder = make_decoder(activation, latent_size, IMAGE_SHAPE, base_depth)
    latent_prior = make_mixture_prior(latent_size,
                                      params["mixture_components"])

    image_tile_summary("input", _as_float(features), rows=1, cols=16)

    approx_posterior = encoder(features)
    num_samples = params["n_samples"]
    approx_posterior_sample = approx_posterior.sample(num_samples)
    decoder_likelihood = decoder(approx_posterior_sample)
    recon_sample = _as_float(decoder_likelihood.sample()[:3, :16])
    image_tile_summary("recon/sample", recon_sample, rows=3, cols=16)
    recon_mean = decoder_likelihood.mean()[:3, :16]
    image_tile_summary("recon/mean", recon_mean, rows=3, cols=16)

    # Distortion is the negative log-likelihood of the data under the decoder.
    distortion = -decoder_likelihood.log_prob(features)
    avg_distortion = tf.reduce_mean(input_tensor=distortion)
    tf.compat.v1.summary.scalar("distortion", avg_distortion)

    if not params["analytic_kl"]:
        # Monte Carlo KL estimate from the posterior samples drawn above.
        rate = (approx_posterior.log_prob(approx_posterior_sample) -
                latent_prior.log_prob(approx_posterior_sample))
    else:
        rate = tfd.kl_divergence(approx_posterior, latent_prior)
    avg_rate = tf.reduce_mean(input_tensor=rate)
    tf.compat.v1.summary.scalar("rate", avg_rate)

    elbo_local = -(rate + distortion)
    elbo = tf.reduce_mean(input_tensor=elbo_local)
    loss = -elbo
    tf.compat.v1.summary.scalar("elbo", elbo)

    # log of the mean importance weight over the samples on axis 0.
    iw_bound_per_example = (
        tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
        tf.math.log(tf.cast(num_samples, dtype=tf.float32)))
    importance_weighted_elbo = tf.reduce_mean(
        input_tensor=iw_bound_per_example)
    tf.compat.v1.summary.scalar("elbo/importance_weighted",
                                importance_weighted_elbo)

    # Decode samples from the prior for visualization.
    random_image = decoder(latent_prior.sample(16))
    image_tile_summary("random/sample", _as_float(random_image.sample()),
                       rows=4, cols=4)
    image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

    # Perform variational inference by minimizing the -ELBO.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    learning_rate = tf.compat.v1.train.cosine_decay(
        params["learning_rate"], global_step, params["max_steps"])
    tf.compat.v1.summary.scalar("learning_rate", learning_rate)
    train_op = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    streaming_mean = tf.compat.v1.metrics.mean
    eval_metric_ops = {
        "elbo": streaming_mean(elbo),
        "elbo/importance_weighted": streaming_mean(importance_weighted_elbo),
        "rate": streaming_mean(avg_rate),
        "distortion": streaming_mean(avg_distortion),
    }
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)
    def __init__(self,
                 env,
                 policy,
                 context_encoder,
                 context_encoder_recurrent=False,
                 expert_trajs=None,
                 reward_arch=relu_net,
                 reward_arch_args=None,
                 value_fn_arch=relu_net,
                 score_discrim=False,
                 discount=1.0,
                 state_only=True,
                 max_path_length=500,
                 meta_batch_size=16,
                 max_itrs=100,
                 fusion=False,
                 latent_dim=3,
                 imitation_coeff=1.0,
                 info_coeff=1.0,
                 name='info_airl'):
        """Build the InfoAIRL graph: context inference, reward, and losses.

        Builds the following graph:
        - a context encoder that infers a latent m from expert trajectories
          (reparameterised sample);
        - a reward network and a shaping value network, both conditioned
          on m;
        - an AIRL-style discriminator loss;
        - a mutual-information loss plus its surrogate for the reward
          weights;
        - a one-shot imitation (negative log-likelihood) loss for `policy`.

        The gradients of all of these losses are applied by the single op
        `self.step`.

        Args:
            env: Environment; `env.spec` supplies the observation/action
                sizes and `env.action_space` must be a `Box`.
            policy: Policy exposing `dist_info_sym`, `distribution` and
                `get_params`; trained by the imitation loss.
            context_encoder: Module exposing `dist_info_sym` (returning
                "mean" and "log_std"), `distribution` and `get_params`.
            context_encoder_recurrent: Not used in this constructor.
            expert_trajs: Forwarded to `self.set_demos`.
            reward_arch: Builder for the reward network (called with `dout=1`).
            reward_arch_args: Extra keyword arguments for `reward_arch`;
                `None` means none.
            value_fn_arch: Builder for the shaping value network; must not be
                None.
            score_discrim: Stored on `self.score_discrim`.
            discount: Shaping discount, stored as `self.gamma`.
            state_only: If True the reward sees observations only.
            max_path_length: Fixed trajectory length T used in placeholders.
            meta_batch_size: Leading dimension of every trajectory placeholder.
            max_itrs: Stored on `self.max_itrs`.
            fusion: If True, creates `RamFusionDistrCustom(100,
                subsample_ratio=0.5)` as `self.fusion`.
            latent_dim: Size of the inferred context m; subtracted from the
                flat observation size to obtain `self.dO`.
            imitation_coeff: Weight of the imitation loss.
            info_coeff: Weight of the mutual-information losses.
            name: Variable scope for all ops created here.
        """
        super(InfoAIRL, self).__init__()
        env_spec = env.spec
        if reward_arch_args is None:
            reward_arch_args = {}

        if fusion:
            self.fusion = RamFusionDistrCustom(100, subsample_ratio=0.5)
        else:
            self.fusion = None
        # Observation size without the latent_dim context features
        # (presumably the env appends the context to observations — confirm).
        self.dO = env_spec.observation_space.flat_dim - latent_dim
        self.dU = env_spec.action_space.flat_dim
        assert isinstance(env.action_space, Box)
        self.context_encoder = context_encoder
        self.score_discrim = score_discrim
        self.gamma = discount
        assert value_fn_arch is not None
        self.set_demos(expert_trajs)
        self.state_only = state_only
        self.T = max_path_length
        self.max_itrs = max_itrs
        self.latent_dim = latent_dim
        self.meta_batch_size = meta_batch_size
        self.policy = policy

        # build energy model
        with tf.variable_scope(name) as _vs:
            # Should be meta_batch_size x batch_size x T x dO/dU
            self.expert_traj_var = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dO + self.dU],
                name='expert_traj')
            self.sample_traj_var = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dO + self.dU],
                name='sample_traj')
            self.obs_t = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dO],
                name='obs')
            self.nobs_t = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dO],
                name='nobs')
            self.act_t = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dU],
                name='act')
            self.nact_t = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dU],
                name='nact')
            # One label per trajectory; presumably 1 = expert, 0 = policy
            # sample (the info losses below weight by 1 - labels) — confirm.
            self.labels = tf.placeholder(tf.float32,
                                         [meta_batch_size, None, 1, 1],
                                         name='labels')
            # Per-step log-probabilities of the sampling policy (log q).
            self.lprobs = tf.placeholder(tf.float32,
                                         [meta_batch_size, None, self.T, 1],
                                         name='log_probs')
            self.lr = tf.placeholder(tf.float32, (), name='lr')

            self.imitation_expert_obses = tf.placeholder(
                tf.float32,
                [meta_batch_size, None, self.T, self.dO])  # None = 1
            self.imitation_expert_acts = tf.placeholder(
                tf.float32, [meta_batch_size, None, self.T, self.dU])
            imitation_expert_obses = tf.reshape(self.imitation_expert_obses,
                                                [-1, self.dO])
            imitation_expert_acts = tf.reshape(self.imitation_expert_acts,
                                               [-1, self.dU])

            with tf.variable_scope('discrim') as dvs:
                # infer m_hat
                # Each trajectory is flattened into one vector for the encoder.
                expert_traj_var = tf.reshape(
                    self.expert_traj_var, [-1, (self.dO + self.dU) * self.T])
                # m_hat should be of shape meta_batch_size x (batch_size*2) x T x latent_dim
                context_dist_info_vars = self.context_encoder.dist_info_sym(
                    expert_traj_var)
                context_mean_var = context_dist_info_vars["mean"]
                context_log_std_var = context_dist_info_vars["log_std"]
                # Reparameterised sample: m = mean + exp(log_std) * eps.
                eps = tf.random.normal(shape=tf.shape(context_mean_var))
                reparam_latent = eps * tf.exp(
                    context_log_std_var) + context_mean_var

                # Repeat the latent for every one of the T timesteps.
                self.reparam_latent_tile = reparam_latent_tile = tf.tile(
                    tf.expand_dims(reparam_latent, axis=1), [1, self.T, 1])

                # One shot imitation
                # Uses the latent of the first trajectory in each meta-batch
                # entry ([:, 0]) to condition the policy on expert observations.
                self.imitation_reparam_latent_tile = tf.reshape(
                    tf.reshape(
                        self.reparam_latent_tile,
                        [meta_batch_size, -1, self.T, latent_dim])[:, 0, :, :],
                    [-1, latent_dim])
                concat_obses_batch = tf.concat([
                    imitation_expert_obses, self.imitation_reparam_latent_tile
                ],
                                               axis=1)
                policy_dist_info_vars = policy.dist_info_sym(
                    obs_var=concat_obses_batch)
                policy_likelihood_loss = -tf.reduce_mean(
                    policy.distribution.log_likelihood_sym(
                        imitation_expert_acts, policy_dist_info_vars))

                reparam_latent_tile = tf.reshape(reparam_latent_tile,
                                                 [-1, latent_dim])

                rew_input = self.obs_t
                if not self.state_only:
                    rew_input = tf.concat([self.obs_t, self.act_t], axis=-1)
                # condition on inferred m
                rew_input = tf.concat([
                    tf.reshape(rew_input,
                               [-1, rew_input.get_shape().dims[-1].value]),
                    reparam_latent_tile
                ],
                                      axis=1)
                with tf.variable_scope('reward'):
                    self.reward = reward_arch(rew_input,
                                              dout=1,
                                              **reward_arch_args)
                    # Undiscounted per-trajectory return, [meta_batch_size, -1, 1].
                    self.sampled_traj_return = tf.reduce_sum(tf.reshape(
                        self.reward, [meta_batch_size, -1, self.T]),
                                                             axis=-1,
                                                             keepdims=True)
                # with tf.variable_scope('reward', reuse=True):
                #     self.sampled_traj_return = reward_arch(tf.reshape(self.sampled_traj_var, [-1, self.dO+self.dU]), dout=1, **reward_arch)
                #     self.sampled_traj_return = tf.reduce_sum(tf.reshape(self.sampled_traj_return, [meta_batch_size, -1, self.T]), axis=-1)
                #energy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

                npotential_input = tf.concat([
                    tf.reshape(self.nobs_t, [-1, self.dO]), reparam_latent_tile
                ],
                                             axis=-1)
                potential_input = tf.concat([
                    tf.reshape(self.obs_t, [-1, self.dO]), reparam_latent_tile
                ],
                                            axis=-1)

                # value function shaping
                # Same 'vfn' weights (reuse=True) evaluated on s' and on s.
                with tf.variable_scope('vfn'):
                    fitted_value_fn_n = value_fn_arch(npotential_input, dout=1)
                with tf.variable_scope('vfn', reuse=True):
                    self.value_fn = fitted_value_fn = value_fn_arch(
                        potential_input, dout=1)

                # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
                self.qfn = self.reward + self.gamma * fitted_value_fn_n
                log_p_tau = self.reward + self.gamma * fitted_value_fn_n - fitted_value_fn

            log_q_tau = self.lprobs
            log_p_tau = tf.reshape(log_p_tau, [meta_batch_size, -1, self.T, 1])

            # log(p + q) via logsumexp for numerical stability;
            # discrim_output = p / (p + q).
            log_pq = tf.reduce_logsumexp(
                [log_p_tau, log_q_tau],
                axis=0)  # [meta_batch_size, -1, self.T, 1]
            self.discrim_output = tf.exp(log_p_tau - log_pq)
            cent_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                        (1 - self.labels) *
                                        (log_q_tau - log_pq))

            # compute mutual information loss
            # sampled_traj_var = tf.reshape(tf.concat([self.obs_t, self.act_t], axis=-1), [-1, (self.dO+self.dU)*self.T])
            log_q_m_tau = tf.reshape(
                self.context_encoder.distribution.log_likelihood_sym(
                    reparam_latent, context_dist_info_vars),
                [meta_batch_size, -1, 1])
            # Used for computing gradient w.r.t. psi
            # Mean of log q(m|tau) over entries with label 0, renormalised by
            # the fraction of such entries.
            info_loss = -tf.reduce_mean(log_q_m_tau *
                                        (1 - tf.squeeze(self.labels, axis=-1))
                                        ) / tf.reduce_mean(1 - self.labels)
            # Used for computing the gradient w.r.t. theta
            # Score-function style surrogate: log q(m|tau) weighted by the
            # masked trajectory return minus its masked mean over axis 1
            # (a baseline).
            info_surr_loss = -tf.reduce_mean(
                (1 - tf.squeeze(self.labels, axis=-1)) * log_q_m_tau *
                self.sampled_traj_return -
                (1 - tf.squeeze(self.labels, axis=-1)) * log_q_m_tau *
                tf.reduce_mean(self.sampled_traj_return *
                               (1 - tf.squeeze(self.labels, axis=-1)),
                               axis=1,
                               keepdims=True) / tf.reduce_mean(1 - self.labels)
            ) / tf.reduce_mean(1 - self.labels)

            self.loss = cent_loss + info_coeff * info_loss
            self.info_loss = info_loss
            self.policy_likelihood_loss = policy_likelihood_loss
            tot_loss = self.loss
            context_encoder_weights = self.context_encoder.get_params(
                trainable=True)
            # reward_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="reward")
            # NOTE(review): substring matching on the global trainable set
            # also picks up any other variable whose name contains "reward"
            # or "vfn" (e.g. from other scopes/models) — confirm.
            reward_weights = [
                i for i in tf.trainable_variables() if "reward" in i.name
            ]
            # value_fn_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="vfn")
            value_fn_weights = [
                i for i in tf.trainable_variables() if "vfn" in i.name
            ]

            # self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss)
            # Each loss gets gradients only for its own variable subset; all
            # pairs are then applied together by one Adam step.
            # NOTE(review): context_encoder_weights appear in three of these
            # lists, and reward_weights in two, so apply_gradients receives
            # several (grad, var) pairs for the same variable — confirm this
            # is intended rather than summing the gradients first.
            optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
            grads_and_vars_cent = optimizer.compute_gradients(
                cent_loss,
                var_list=reward_weights + value_fn_weights +
                context_encoder_weights)
            grads_and_vars_context = optimizer.compute_gradients(
                info_coeff * info_loss, var_list=context_encoder_weights)
            grads_and_vars_reward = optimizer.compute_gradients(
                info_coeff * info_surr_loss, var_list=reward_weights)
            grads_and_vars_policy = optimizer.compute_gradients(
                imitation_coeff * policy_likelihood_loss,
                var_list=self.policy.get_params(trainable=True) +
                context_encoder_weights)
            self.step = optimizer.apply_gradients(grads_and_vars_cent +
                                                  grads_and_vars_context +
                                                  grads_and_vars_reward +
                                                  grads_and_vars_policy)

            # grads_and_vars_cent = optimizer.compute_gradients(cent_loss, var_list=reward_weights+value_fn_weights)
            # grads_and_vars_reward = optimizer.compute_gradients(info_coeff*info_surr_loss, var_list=reward_weights)
            # self.step = optimizer.apply_gradients(grads_and_vars_cent+grads_and_vars_reward)

            self._make_param_ops(_vs)
# Example #7
# 0
    def _build(self, x, presence=None):
        """Order-invariant likelihood: every point is scored against all votes.

        Because the correspondence between input points and votes is unknown,
        each point is modelled by a mixture over all votes plus one "dummy"
        component whose log-likelihood and logit are both -2*log(10). The
        mixture uses the normalised mixing log-probabilities.

        Args:
            x: Points of shape [B, n_input_points, n_input_dims]; the first
                two dimensions must be statically known (they size tf.zeros).
            presence: Optional per-point mask; when given it is cast to float
                and multiplies the per-point log-likelihood.

        Returns:
            `self.OutputTuple` with the batch-mean log-likelihood, the hard
            winning vote/presence per point, posterior mixing probabilities
            ([B, n_points, n_votes], dummy dropped) and mixing logits /
            log-probs. `soft_winner*` are zero placeholders (see TODO).
        """

        batch_size, n_input_points = x.shape[:2].as_list()

        # we don't know what order the initial points came in, so we need to create
        # a big mixture of all votes for every input point
        # [B, 1, n_votes, n_input_dims]
        expanded_votes = tf.expand_dims(self._votes, 1)
        expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
        vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

        # [B, n_points, 1, n_input_dims]; broadcasts against the vote axis.
        expanded_x = tf.expand_dims(x, 2)
        vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
        # [B, n_points, n_votes]
        vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
        dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
        dummy_vote_log_prob -= 2. * tf.log(10.)
        # [B, n_points, n_votes + 1]
        vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

        # [B, n_votes] (rank 2: it is concatenated with a [B, 1] tensor below)
        mixing_logits = math_ops.safe_log(self._vote_presence_prob)

        dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
        # [B, n_votes + 1]
        mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

        # Mixing logits normalised over the vote axis.
        mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
            mixing_logits, 1, keepdims=True)

        # [B, 1, n_votes + 1]; despite the name this holds the *normalised*
        # mixing log-probs, shared by every point.
        expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
        # [B, n_points]: per-point mixture log-likelihood.
        mixture_log_prob_per_component\
          = tf.reduce_logsumexp(expanded_mixing_logits + vote_log_prob, 2)

        if presence is not None:
            presence = tf.to_float(presence)
            # Masked-out points contribute 0 to the log-likelihood.
            mixture_log_prob_per_component *= presence

        # [B]
        mixture_log_prob_per_example\
          = tf.reduce_sum(mixture_log_prob_per_component, 1)

        # []
        mixture_log_prob_per_batch = tf.reduce_mean(
            mixture_log_prob_per_example)

        # [B, n_points, n_votes + 1]
        posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
        # [B, n_points]; the dummy component is excluded from the argmax.
        winning_vote_idx = tf.argmax(
            posterior_mixing_logits_per_point[:, :, :-1], 2)

        batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
        batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

        # [B, n_points, 2] indices (batch, winning vote).
        idx = tf.stack([batch_idx, winning_vote_idx], -1)
        winning_vote = tf.gather_nd(self._votes, idx)
        winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
        # A vote counts as present when its logit exceeds the dummy logit.
        vote_presence = tf.greater(mixing_logits[:, :-1], mixing_logits[:,
                                                                        -1:])

        # Integer division of the winning vote index by self._n_votes
        # (presumably votes per capsule, giving the owning capsule — confirm).
        is_from_capsule = winning_vote_idx // self._n_votes

        # Softmax over the vote axis, then drop the dummy component.
        posterior_mixing_probs = tf.nn.softmax(
            posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

        assert winning_vote.shape == x.shape

        return self.OutputTuple(
            log_prob=mixture_log_prob_per_batch,
            vote_presence=tf.to_float(vote_presence),
            winner=winning_vote,
            winner_pres=winning_pres,
            is_from_capsule=is_from_capsule,
            mixing_logits=mixing_logits,
            mixing_log_prob=mixing_log_prob,
            # TODO(adamrk): this is broken
            soft_winner=tf.zeros_like(winning_vote),
            soft_winner_pres=tf.zeros_like(winning_pres),
            posterior_mixing_probs=posterior_mixing_probs,
        )
# Example #8
# 0
def model_fn(features, labels, mode, params):
  """Model function for retrieval-augmented masked language modelling.

  Scores `num_candidates` candidate documents per query with dual embedders,
  reads every (query, candidate) pair with a BERT MLM head, and trains on the
  masked-token log-likelihood marginalised over candidates.

  Args:
    features: Dict of input tensors. Reads "query_inputs", "candidate_inputs"
      and "joint_inputs" (objects exposing `token_ids`, `mask` and
      `segment_ids`), plus "mlm_targets", "mlm_positions", "mlm_mask" and,
      outside PREDICT mode, "export_timestamp".
    labels: Unused.
    mode: A `tf.estimator.ModeKeys` value.
    params: Hyperparameter dict. Reads "bert_hub_module_handle",
      "embedder_hub_module_handle", "share_embedders", "num_candidates",
      "learning_rate", "num_train_steps" and "use_tpu".

  Returns:
    A `TPUEstimatorSpec` when `params["use_tpu"]` is set, otherwise an
    `EstimatorSpec`.
  """
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  # Use the "train"-tagged graph variant in TRAIN mode and the default
  # (empty tag set) variant otherwise.
  module_tags = {"train"} if mode == tf.estimator.ModeKeys.TRAIN else set()

  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="bert",
      tags=module_tags,
      trainable=True)
  hub.register_module_for_export(bert_module, "bert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags=module_tags,
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    # The scope name "embedder" is kept as-is (hub uniquifies it); changing
    # it would rename variables stored in existing checkpoints.
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags=module_tags,
        trainable=True)
    # Export the separate query tower itself, not the candidate embedder.
    hub.register_module_for_export(query_embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(
      candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]; inner-product relevance scores.
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, _ = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # Dict of outputs; "mlm_logits" is read below.
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]; log p(candidate | query).
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # Presumably [batch_size * num_candidates, num_masks, vocab_size]; the
  # reshape below infers the vocabulary axis.
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]; log p(candidate) + log p(y | candidate).
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]; marginalise over candidates.
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []; mean negative log-likelihood over unmasked positions (0 if none).
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  # Warm up for 10% of training, clamped to [100, 10000] steps.
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  # NOTE(review): on TPU this stays None and is still passed to
  # add_mean_metric below; presumably that helper ignores None — confirm.
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]; gain of the marginal over candidate 0 alone.
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= float_mlm_mask

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    # 1.0 where an export timestamp is set (> 0), else 0.0.
    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    # Zero out the delay for examples without an export timestamp.
    off_policy_delay_mins = off_policy_delay_secs / 60.0 * has_timestamp

    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
# Example #9
# 0
def iwae(p_z,
         p_x_given_z,
         q_z,
         observations,
         num_samples,
         cvs,
         contexts=None,
         antithetic=False):
    """Computes a gradient of the IWAE estimator.

    Builds a family of surrogate objectives that share the same importance
    weights log p(z) + log p(x|z) - log q(z|x) but differ in how the model and
    inference-network gradient terms are formed (IWAE, RWS, STL, DReG
    variants, jackknife, and a control-variate combination "dreg-cv").

    Args:
      p_z: The prior. Should be a callable that optionally accepts a
        conditioning context and returns a tfp.distributions.Distribution
        which has the log_prob and sample methods implemented. The
        distribution should be over a [batch_size, latent_dim] space.
      p_x_given_z: The likelihood. Should be a callable that accepts as input
        a tensor of shape [num_samples, batch_size, latent_size + context_size]
        and returns a tfd.Distribution over a [num_samples, batch_size,
        data_dim] space.
      q_z: The proposal, should be a callable which accepts a batch of
        observations of shape [batch_size, data_dim] and returns a
        distribution over [batch_size, latent_dim].
      observations: A float Tensor of shape [batch_size, data_dim] containing
        the observations.
      num_samples: The number of samples for the IWAE estimator. Must be even
        when `antithetic` is True.
      cvs: Control variate variables, unpacked as (alpha, beta, gamma, delta)
        and used by the "dreg-cv" estimator.
      contexts: A float Tensor of shape [batch_size, context_dim] containing
        the contexts. (Optionally, none)
      antithetic: Whether to use antithetic sampling. Draws
        `num_samples // 2` samples z and pairs each with
        `2 * proposal.loc - z`, so the proposal must expose `loc`.

    Returns:
      estimators: Dictionary mapping an estimator name to a tuple
        (objective, neg_model_loss, neg_inference_network_loss).

    Raises:
      ValueError: If `antithetic` is True and `num_samples` is a static
        (Python/NumPy) odd integer.
    """
    import numbers  # Local: only needed for the argument check below.

    # With antithetic sampling an odd count would silently draw only
    # num_samples - 1 samples while every normalization below divides by
    # num_samples (and contexts are tiled num_samples times). This can only
    # be checked when num_samples is a static integer.
    if (antithetic and isinstance(num_samples, numbers.Integral) and
            num_samples % 2 != 0):
        raise ValueError(
            "antithetic sampling requires an even num_samples, got {}".format(
                num_samples))

    alpha, beta, gamma, delta = cvs
    batch_size = tf.shape(observations)[0]
    proposal = q_z(observations, contexts, stop_gradient=False)
    # [num_samples, batch_size, latent_size]

    # If antithetic sampling, draw half of the samples and use the antithetics
    # for the other half.
    if antithetic:
        z_pos = proposal.sample(sample_shape=[num_samples // 2])
        z_neg = 2 * proposal.loc - z_pos
        z = tf.concat((z_pos, z_neg), axis=0)
    else:
        z = proposal.sample(sample_shape=[num_samples])

    # Repeat the contexts once per sample so they line up with z.
    tiled_contexts = None
    if contexts is not None:
        tiled_contexts = tf.tile(tf.expand_dims(contexts, 0),
                                 [num_samples, 1, 1])
    likelihood = p_x_given_z(z, tiled_contexts)
    # Before reduce_sum is [num_samples, batch_size, latent_dim].
    # Sum over the latent dim.
    log_q_z = tf.reduce_sum(proposal.log_prob(z), axis=-1)
    # Before reduce_sum is  [num_samples, batch_size, latent_dim].
    # Sum over latent dim.
    prior = p_z(contexts)
    log_p_z = tf.reduce_sum(prior.log_prob(z), axis=-1)
    # Before reduce_sum is [num_samples, batch_size, data_dim]
    log_p_x_given_z = tf.reduce_sum(likelihood.log_prob(observations), axis=-1)

    # Log importance weights, [num_samples, batch_size].
    log_weights = log_p_z + log_p_x_given_z - log_q_z
    # log(mean_k w_k) per example: the IWAE bound, [batch_size].
    log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
    log_avg_weight = log_sum_weight - tf.log(tf.to_float(num_samples))
    # Self-normalized weights over the sample axis, treated as constants.
    normalized_weights = tf.stop_gradient(tf.nn.softmax(log_weights, axis=0))

    if FLAGS.image_summary:
        # Summarize the decoder output for the highest-weight sample.
        best_index = tf.to_int32(tf.argmax(normalized_weights, axis=0))
        indices = tf.stack((best_index, tf.range(0, batch_size)), axis=-1)
        best_images = tf.gather_nd(likelihood.probs_parameter(), indices)

        if FLAGS.dataset == "struct_mnist":
            tf.summary.image("bottom_half",
                             tf.reshape(best_images, [batch_size, -1, 28, 1]))
        else:
            tf.summary.image("output",
                             tf.reshape(best_images, [batch_size, -1, 28, 1]))
        tf.summary.image("input",
                         tf.reshape(observations, [batch_size, -1, 28, 1]))

    # Compute gradient estimators
    model_loss = log_avg_weight
    estimators = {}

    estimators["iwae"] = (log_avg_weight, log_avg_weight, log_avg_weight)

    # log q evaluated at z with the sample path stopped: gradients reach the
    # proposal parameters only through the density (score-function style).
    stopped_z_log_q_z = tf.reduce_sum(proposal.log_prob(tf.stop_gradient(z)),
                                      axis=-1)
    estimators["rws"] = (log_avg_weight, model_loss,
                         tf.reduce_sum(normalized_weights * stopped_z_log_q_z,
                                       axis=0))

    # Doubly reparameterized
    # Log weights with the proposal parameters stopped in log q, so gradients
    # reach the proposal only through the reparameterized sample z.
    stopped_proposal = q_z(observations, contexts, stop_gradient=True)
    stopped_log_q_z = tf.reduce_sum(stopped_proposal.log_prob(z), axis=-1)
    stopped_log_weights = log_p_z + log_p_x_given_z - stopped_log_q_z
    sq_normalized_weights = tf.square(normalized_weights)

    estimators["stl"] = (log_avg_weight, model_loss,
                         tf.reduce_sum(normalized_weights *
                                       stopped_log_weights,
                                       axis=0))
    estimators["dreg"] = (log_avg_weight, model_loss,
                          tf.reduce_sum(sq_normalized_weights *
                                        stopped_log_weights,
                                        axis=0))
    estimators["rws-dreg"] = (
        log_avg_weight, model_loss,
        tf.reduce_sum(
            (normalized_weights - sq_normalized_weights) * stopped_log_weights,
            axis=0))

    # Add normed versions
    normalized_sq_normalized_weights = (
        sq_normalized_weights /
        tf.reduce_sum(sq_normalized_weights, axis=0, keepdims=True))
    estimators["dreg-norm"] = (log_avg_weight, model_loss,
                               tf.reduce_sum(normalized_sq_normalized_weights *
                                             stopped_log_weights,
                                             axis=0))

    rws_dregs_weights = normalized_weights - sq_normalized_weights
    normalized_rws_dregs_weights = rws_dregs_weights / tf.reduce_sum(
        rws_dregs_weights, axis=0, keepdims=True)
    estimators["rws-dreg-norm"] = (log_avg_weight, model_loss,
                                   tf.reduce_sum(normalized_rws_dregs_weights *
                                                 stopped_log_weights,
                                                 axis=0))

    # Convex combination of the "dreg" and "rws-dreg" inference terms.
    estimators["dreg-alpha"] = (log_avg_weight, model_loss,
                                (1 - FLAGS.alpha) * estimators["dreg"][-1] +
                                FLAGS.alpha * estimators["rws-dreg"][-1])

    # Jackknife
    # [batch_size, num_samples, num_samples]; entry [b, j, i] is
    # log_weights[j, b], with the j == i diagonal masked to -inf so that
    # reducing over axis 1 leaves sample i out.
    loo_log_weights = tf.tile(tf.expand_dims(tf.transpose(log_weights), -1),
                              [1, 1, num_samples])
    loo_log_weights = tf.matrix_set_diag(
        loo_log_weights, -np.inf * tf.ones([batch_size, num_samples]))
    # Mean over i of the leave-one-out IWAE bound, [batch_size].
    loo_log_avg_weight = tf.reduce_mean(
        tf.reduce_logsumexp(loo_log_weights, axis=1) -
        tf.log(tf.to_float(num_samples - 1)),
        axis=-1)
    jk_model_loss = num_samples * log_avg_weight - (num_samples -
                                                    1) * loo_log_avg_weight

    estimators["jk"] = (jk_model_loss, jk_model_loss, jk_model_loss)

    # Compute JK w/ DReG for the inference network
    # Squared leave-one-out normalized weights averaged over the left-out
    # index i: [batch_size, num_samples].
    loo_normalized_weights = tf.reduce_mean(tf.square(
        tf.stop_gradient(tf.nn.softmax(loo_log_weights, axis=1))),
                                            axis=-1)
    estimators["jk-dreg"] = (
        jk_model_loss, jk_model_loss, num_samples *
        tf.reduce_sum(sq_normalized_weights * stopped_log_weights, axis=0) -
        (num_samples - 1) * tf.reduce_sum(
            tf.transpose(loo_normalized_weights) * stopped_log_weights, axis=0)
    )

    # Compute control variates
    # loo_baseline[i, b] = logsumexp over j != i of log_weights[j, b],
    # shape [num_samples, batch_size].
    loo_baseline = tf.expand_dims(tf.transpose(log_weights), -1)
    loo_baseline = tf.tile(loo_baseline, [1, 1, num_samples])
    loo_baseline = tf.matrix_set_diag(
        loo_baseline, -np.inf * tf.ones_like(tf.transpose(log_weights)))
    loo_baseline = tf.reduce_logsumexp(loo_baseline, axis=1)
    loo_baseline = tf.transpose(loo_baseline)

    learning_signal = tf.stop_gradient(tf.expand_dims(
        log_avg_weight, 0)) - (1 - gamma) * tf.stop_gradient(loo_baseline)
    vimco = tf.reduce_sum(learning_signal * stopped_z_log_q_z, axis=0)

    first_part = alpha * vimco + (1 - alpha) * tf.reduce_sum(
        normalized_weights * stopped_log_weights, axis=0)
    second_part = ((1 - beta) * (tf.reduce_sum(
        ((1 - delta) / tf.to_float(num_samples) - normalized_weights) *
        stopped_z_log_q_z,
        axis=0)) + beta * tf.reduce_sum(
            (sq_normalized_weights - normalized_weights) * stopped_log_weights,
            axis=0))
    estimators["dreg-cv"] = (log_avg_weight, model_loss,
                             first_part + second_part)

    return estimators
Beispiel #10
0
def _logmatmulexp(a, b):
    """Matrix product computed in log space: ``log(exp(a) @ exp(b))``.

    ``a[..., i, k]`` and ``b[..., k, j]`` are broadcast into a
    ``[..., i, k, j]`` tensor whose shared ``k`` axis is then reduced with
    log-sum-exp.
    """
    lhs = tf.expand_dims(a, -1)  # [..., i, k, 1]
    rhs = tf.expand_dims(b, -3)  # [..., 1, k, j]
    pairwise = lhs + rhs  # [..., i, k, j]
    return tf.reduce_logsumexp(pairwise, axis=-2)
Beispiel #11
0
def marginal_log_loss(logits, is_correct):
    """Loss based on the negative marginal log-likelihood.

    Returns ``logsumexp(logits) - logsumexp(logits + mask_to_score(is_correct))``
    reduced over the last axis, i.e. the negative log of the total softmax
    mass placed on the candidates flagged by `is_correct`. `mask_to_score` is
    defined elsewhere; presumably it maps incorrect entries to a very
    negative score — verify against its definition.

    Args:
      logits: Float Tensor of candidate scores; the last axis indexes
        candidates.
      is_correct: Mask broadcast-compatible with `logits` marking the correct
        candidates.

    Returns:
      Tensor shaped like `logits` with its last axis removed.
    """
    log_partition = tf.reduce_logsumexp(logits, -1)
    correct_scores = logits + mask_to_score(is_correct)
    log_correct_mass = tf.reduce_logsumexp(correct_scores, -1)
    return log_partition - log_correct_mass
Beispiel #12
0
 def log_prob(self, x):
     """Log-density of `x` under the mixture.

     A component axis is inserted at position 1 so each component scores each
     point. The mixing log-weights are then added, and the result is
     log-sum-exp'd over that component axis.

     NOTE(review): `self.mixing_log_prob` is read as an attribute, so it is
     presumably exposed as a property — confirm.
     """
     per_component = self._component_log_prob(tf.expand_dims(x, 1))
     weighted = per_component + self.mixing_log_prob
     return tf.reduce_logsumexp(weighted, 1)
Beispiel #13
0
 def mixing_log_prob(self):
     """Mixing logits normalized to log-probabilities over axis 1."""
     logits = self._mixing_logits
     log_normalizer = tf.reduce_logsumexp(logits, 1, keepdims=True)
     return logits - log_normalizer
Beispiel #14
0
def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an estimator.

    Constructs a VAE whose reconstruction term (the "distortion") is the
    negative log-likelihood from `adaptive.AdaptiveImageLossFunction` applied
    to the decoder residual. The prior is either a mixture prior
    (`make_mixture_prior`) or, with `FLAGS.floating_prior`, a zero-mean
    diagonal Gaussian fitted to batch statistics of the approximate
    posterior. The model is trained with Adam on -ELBO (or -BILBO when
    `FLAGS.bilbo`) under a cosine-decayed learning rate. Rate, distortion and
    ELBO summaries and image summaries are logged along the way.

    Arguments:
      features: The input features for the estimator (images; presumably
        [batch, IMAGE_SIZE, IMAGE_SIZE, 3] — confirm against the input fn).
      labels: The labels, unused here.
      mode: Signifies whether it is train or test or predict.
      params: Some parameters, unused here.
      config: The RunConfig, unused here.

    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec instance.

    Raises:
      NotImplementedError: If an unsupported combination of FLAGS is set.
    """
    del labels, params, config

    # Reject flag combinations this model does not implement.
    if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
        raise NotImplementedError(
            "Using `analytic_kl` is only supported when `mixture_components = 1` "
            "since there's no closed form otherwise.")
    if FLAGS.floating_prior and not (FLAGS.unit_posterior
                                     and FLAGS.mixture_components == 1):
        raise NotImplementedError(
            "Using `floating_prior` is only supported when `unit_posterior` = True "
            "since there's a scale ambiguity otherwise, and when "
            "`mixture_components = 1` since there's no closed form otherwise.")
    if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
        raise NotImplementedError(
            "Using `fitted_samples` is only supported when "
            "`mixture_components = 1` since there's no closed form otherwise.")
    if FLAGS.bilbo and not FLAGS.floating_prior:
        raise NotImplementedError(
            "Using `bilbo` is only supported when `floating_prior = True`.")

    activation = tf.nn.leaky_relu
    encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
    decoder = make_decoder(activation, FLAGS.latent_size,
                           [IMAGE_SIZE] * 2 + [3], FLAGS.base_depth)

    # Encode, then decode FLAGS.n_samples posterior samples per input.
    approx_posterior = encoder(features)
    approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
    decoder_mu = decoder(approx_posterior_sample)

    if FLAGS.floating_prior or FLAGS.fitted_samples:
        # Per-dimension second moment of the aggregate posterior over the
        # batch: mean of squared posterior means plus mean posterior variance.
        # These locals exist only under these flags; `bilbo` below relies on
        # the validation above that requires `floating_prior`.
        posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
        posterior_batch_variance = tf.reduce_mean(approx_posterior.stddev()**2,
                                                  [0])
        posterior_scale = posterior_batch_mean + posterior_batch_variance
        floating_prior = tfd.MultivariateNormalDiag(
            tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
        tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

    if FLAGS.floating_prior:
        latent_prior = floating_prior
    else:
        latent_prior = make_mixture_prior(FLAGS.latent_size,
                                          FLAGS.mixture_components)

    # Decode samples from the prior for visualization.
    if FLAGS.fitted_samples:
        sample_distribution = floating_prior
    else:
        sample_distribution = latent_prior

    n_samples = VIZ_GRID_SIZE**2
    random_mu = decoder(sample_distribution.sample(n_samples))

    # Flatten all leading (sample, batch) dims into one image batch for the
    # loss function.
    residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

    # The Student's t variant takes no alpha_* arguments.
    if FLAGS.use_students_t:
        lossfun = adaptive.AdaptiveImageLossFunction(
            residual.shape[1:],
            residual.dtype,
            color_space=FLAGS.color_space,
            representation=FLAGS.representation,
            wavelet_num_levels=FLAGS.wavelet_num_levels,
            wavelet_scale_base=FLAGS.wavelet_scale_base,
            use_students_t=FLAGS.use_students_t,
            scale_lo=FLAGS.scale_lo,
            scale_init=FLAGS.scale_init)
    else:
        lossfun = adaptive.AdaptiveImageLossFunction(
            residual.shape[1:],
            residual.dtype,
            color_space=FLAGS.color_space,
            representation=FLAGS.representation,
            wavelet_num_levels=FLAGS.wavelet_num_levels,
            wavelet_scale_base=FLAGS.wavelet_scale_base,
            use_students_t=FLAGS.use_students_t,
            alpha_lo=FLAGS.alpha_lo,
            alpha_hi=FLAGS.alpha_hi,
            alpha_init=FLAGS.alpha_init,
            scale_lo=FLAGS.scale_lo,
            scale_init=FLAGS.scale_init)

    nll = lossfun(residual)

    # Restore the leading two dims of decoder_mu (presumably
    # [n_samples, batch]) on the per-pixel NLL.
    nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                           tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

    # Clipping to prevent the loss from nanning out.
    max_val = np.finfo(np.float32).max
    nll = tf.clip_by_value(nll, -max_val, max_val)

    viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
    viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

    image_tile_summary("input",
                       tf.to_float(features),
                       rows=1,
                       cols=viz_n_inputs)

    image_tile_summary("recon/mean",
                       decoder_mu[:viz_n_samples, :viz_n_inputs],
                       rows=viz_n_samples,
                       cols=viz_n_inputs)

    # These two summaries are also exported through eval_metric_ops below.
    img_summary_input = image_tile_summary("input1",
                                           tf.to_float(features),
                                           rows=viz_n_inputs,
                                           cols=1)
    img_summary_recon = image_tile_summary("recon1",
                                           decoder_mu[:1, :viz_n_inputs],
                                           rows=viz_n_inputs,
                                           cols=1)

    image_tile_summary("random/mean",
                       random_mu,
                       rows=VIZ_GRID_SIZE,
                       cols=VIZ_GRID_SIZE)

    # Sum the NLL over the image axes (height, width, channel).
    distortion = tf.reduce_sum(nll, axis=[2, 3, 4])

    avg_distortion = tf.reduce_mean(distortion)
    tf.summary.scalar("distortion", avg_distortion)

    # Rate: closed-form KL, or a single-sample Monte Carlo estimate
    # log q(z|x) - log p(z) at each posterior sample.
    if FLAGS.analytic_kl:
        rate = tfd.kl_divergence(approx_posterior, latent_prior)
    else:
        rate = (approx_posterior.log_prob(approx_posterior_sample) -
                latent_prior.log_prob(approx_posterior_sample))
    avg_rate = tf.reduce_mean(rate)
    tf.summary.scalar("rate", avg_rate)

    elbo_local = -(rate + distortion)

    elbo = tf.reduce_mean(elbo_local)
    tf.summary.scalar("elbo", elbo)

    if FLAGS.bilbo:
        bilbo = -0.5 * tf.reduce_sum(
            tf.log1p(posterior_batch_mean /
                     posterior_batch_variance)) - avg_distortion
        tf.summary.scalar("bilbo", bilbo)
        loss = -bilbo
    else:
        loss = -elbo

    # Log-mean-exp of the per-sample ELBO over the sample axis (axis 0),
    # averaged over the batch.
    importance_weighted_elbo = tf.reduce_mean(
        tf.reduce_logsumexp(elbo_local, axis=0) -
        tf.math.log(tf.to_float(FLAGS.n_samples)))
    tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

    # Perform variational inference by minimizing the -ELBO.
    # The learning rate is constant until decay_start * max_steps, then
    # cosine-decays over the remaining (1 - decay_start) * max_steps steps.
    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.cosine_decay(
        FLAGS.learning_rate,
        tf.maximum(tf.cast(0, tf.int64),
                   global_step - int(FLAGS.decay_start * FLAGS.max_steps)),
        int((1. - FLAGS.decay_start) * FLAGS.max_steps))
    tf.summary.scalar("learning_rate", learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)

    if mode == tf_estimator.ModeKeys.TRAIN:
        train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        train_op = None

    eval_metric_ops = {}
    eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
    eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
        importance_weighted_elbo)
    eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
    eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
    # This ugly hackery is necessary to get TF to visualize when running the
    # eval set, apparently.
    eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
    eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
    # Normalize metric keys to `str`.
    eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
    )
    def compute_iw_marginal(self,
                            targets,
                            targets_mask,
                            decoder_self_attention_bias,
                            features,
                            n_samples,
                            reduce_mean=True,
                            **kwargs):
        """Importance-weighted estimate of the marginal log-likelihood.

        Draws `n_samples` latents per example with `self.sample_q` and scores
        them under the prior (`self.compute_prior_log_prob`, including the
        log-abs-det term) and the decoder likelihood (`self.loss_iw`). The
        per-sample likelihoods are combined with self-normalized importance
        weights softmax_k(log p(z_k) - log q(z_k)).

        Args:
          targets: Target tensor passed to `self.sample_q`.
          targets_mask: Padding mask for `targets`.
          decoder_self_attention_bias: Attention bias passed to
            `self.sample_q`.
          features: Feature dict; it is shallow-copied and its "targets"
            entry is replaced with the tiled version.
          n_samples: Number of importance samples K.
          reduce_mean: If True, average the estimate over the batch.
          **kwargs: Extra tensors forwarded to `sample_q`. They are tiled
            with `ops.prepare_for_iw` before being passed to the prior and
            the decoder.

        Returns:
          A float32 Tensor. It is a scalar when `reduce_mean` is True,
          otherwise one value per batch element ([B]).
        """
        hparams = self._hparams
        # Posterior samples, [K*B, L, C].
        z_q, log_q_z, _ = self.sample_q(targets,
                                        targets_mask,
                                        decoder_self_attention_bias,
                                        n_samples=n_samples,
                                        temp=1.0,
                                        **kwargs)

        def _tile(tensor):
            # Repeat a per-example tensor once per importance sample.
            return ops.prepare_for_iw(tensor, n_samples)

        tiled_kwargs = {name: _tile(val) for name, val in kwargs.items()}
        tiled_mask = _tile(targets_mask)
        tiled_bias = common_attention.attention_bias_ignore_padding(
            1.0 - tiled_mask)
        tiled_features = copy.copy(features)
        tiled_features["targets"] = _tile(features["targets"])

        prior_base, prior_log_det = self.compute_prior_log_prob(
            z_q,
            tiled_mask,
            tiled_bias,
            check_invertibility=False,
            **tiled_kwargs)
        log_p_z = prior_base + prior_log_det

        decoded = ops.decoder("decoder", z_q, hparams, tiled_bias,
                              **tiled_kwargs)
        logits = self.top(decoded, tiled_features)
        nll_terms, token_terms = self.loss_iw(logits, tiled_features)
        nll_sum = tf.reduce_sum(nll_terms[..., 0, 0], 1)  # [K*B]
        token_count = tf.reduce_sum(token_terms[..., 0, 0], 1)  # [K*B]
        log_p_x = -1 * nll_sum / token_count
        log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, tiled_mask)
        log_p_z = log_p_z / tf.reduce_sum(tiled_mask, 1)

        # Split the flattened K*B axis back into [K, B].
        log_p_x = ops.unprepare_for_iw(log_p_x, n_samples)
        log_q_z = ops.unprepare_for_iw(log_q_z, n_samples)
        log_p_z = ops.unprepare_for_iw(log_p_z, n_samples)

        # Self-normalized log importance weights over the sample axis.
        log_weights = tf.nn.log_softmax(log_p_z - log_q_z, axis=0)  # [K, B]
        iw_marginal = tf.reduce_logsumexp(log_p_x + log_weights, 0)  # [B]

        if reduce_mean:
            iw_marginal = tf.reduce_mean(iw_marginal, 0)  # []
        return tf.cast(iw_marginal, tf.float32)
def model_uncertainty(logits):
    """Mutual information between the categorical label and the model parameters.

    A way to evaluate uncertainty in ensemble models is to measure its spread
    or `disagreement`. One way is to measure the mutual information between
    the categorical label and the parameters of the categorical output. This
    assesses uncertainty in predictions due to `model uncertainty`. Model
    uncertainty can be expressed as the difference of the total uncertainty
    and the expected data uncertainty:
    `Model uncertainty = Total uncertainty - Expected data uncertainty`, where

    * `Total uncertainty`: Entropy of expected predictive distribution.
    * `Expected data uncertainty`: Expected entropy of individual predictive
      distribution.

    This formulation was given by [1, 2] and allows the decomposition of total
    uncertainty into model uncertainty and expected data uncertainty. The
    total uncertainty will be high whenever the model is uncertain. However,
    the model uncertainty, the difference between total and expected data
    uncertainty, will be non-zero iff the ensemble disagrees.

    ## References:
    [1] Depeweg, S., Hernandez-Lobato, J. M., Doshi-Velez, F, and Udluft, S.
        Decomposition of uncertainty for active learning and reliable
        reinforcement learning in stochastic systems.
        stat 1050, p.11, 2017.
    [2] Malinin, A., Mlodozeniec, B., and Gales, M.
        Ensemble Distribution Distillation.
        arXiv:1905.00076, 2019.

    Args:
      logits: Tensor (or value convertible to one), shape (N, k, nc). Logits
        for N instances, k ensemble members and nc classes.

    Raises:
      TypeError: Raised if `logits` is None.
      ValueError: Raised if `logits` has a statically known rank other than 3.

    Returns:
      model_uncertainty: Tensor, shape (N,).
      total_uncertainty: Tensor, shape (N,).
      expected_data_uncertainty: Tensor, shape (N,).
    """
    if logits is None:
        raise TypeError("model_uncertainty expected logits to be set.")
    logits = tf.convert_to_tensor(logits)
    # Check the static rank instead of `tf.rank(...).numpy()`, which only
    # works eagerly. Inputs of unknown rank are left to the ops below.
    rank = logits.shape.rank
    if rank is not None and rank != 3:
        raise ValueError(
            "model_uncertainty expected logits to be of shape (N, k, nc), "
            "instead got {}".format(logits.shape))

    # expected data uncertainty: mean over ensemble members of each member's
    # predictive entropy.
    log_prob = tf.math.log_softmax(logits, -1)
    prob = tf.exp(log_prob)
    expected_data_uncertainty = tf.reduce_mean(
        tf.reduce_sum(-prob * log_prob, -1), -1)

    # log of the ensemble-averaged predictive distribution, computed stably as
    # logsumexp over members minus log(k). The member count is read
    # dynamically and cast to the logits dtype, so unknown static shapes and
    # float64 logits work too.
    n_ens = tf.cast(tf.shape(log_prob)[1], log_prob.dtype)
    log_expected_probabilities = tf.reduce_logsumexp(log_prob,
                                                     1) - tf.math.log(n_ens)
    expected_probabilities = tf.exp(log_expected_probabilities)
    total_uncertainty = tf.reduce_sum(
        -expected_probabilities * log_expected_probabilities, -1)

    model_uncertainty_ = total_uncertainty - expected_data_uncertainty

    return model_uncertainty_, total_uncertainty, expected_data_uncertainty
Beispiel #17
0
def _log_prob_from_logits(logits, axis=2):
    """Normalizes `logits` into log-probabilities along `axis`.

    Computes ``logits - logsumexp(logits, axis)``. The reduced axis is kept
    so that the subtraction broadcasts back over it.

    Args:
      logits: Float Tensor of unnormalized log-probabilities with a dimension
        at `axis`.
      axis: Axis to normalize over. Defaults to 2, the previously hard-coded
        axis.

    Returns:
      Tensor with the same shape as `logits` whose exponentials sum to one
      along `axis`.
    """
    # `keepdims` replaces the deprecated `keep_dims` argument, which TF2's
    # `tf.reduce_logsumexp` no longer accepts.
    return logits - tf.reduce_logsumexp(logits, axis=axis, keepdims=True)
Beispiel #18
0
    def _build_train_op(self):
        """Builds a training op.

        The loss is the quantile Huber loss (Eqs. 9-10 of the paper
        referenced in the comments below). It is computed between the
        chosen-action quantile logits and the stop-gradient target
        distribution. When the replay scheme is "prioritized", replay
        priorities are updated from the per-example loss and the loss is
        importance-weighted. A conservative penalty
        ``mean(logsumexp_a Q(s, a)) - mean(Q(s, a_replay))``, scaled by
        `self.minq_weight`, is added before minimization.

        Returns:
          A tuple ``(train_op, loss)``:
            train_op: An op performing one step of training.
            loss: The per-example quantile loss, after importance weighting
              when prioritized replay is used.
        """
        target_distribution = tf.stop_gradient(
            self._build_target_distribution())

        # size of indices: batch_size x 1.
        indices = tf.range(tf.shape(self._replay_net_outputs.logits)[0])[:,
                                                                         None]
        # size of reshaped_actions: batch_size x 2.
        reshaped_actions = tf.concat([indices, self._replay.actions[:, None]],
                                     1)
        # For each element of the batch, fetch the logits for its selected action.
        chosen_action_logits = tf.gather_nd(self._replay_net_outputs.logits,
                                            reshaped_actions)

        # Pairwise errors: bellman_errors[b, i, j] = target[b, j] - pred[b, i],
        # so axis 1 indexes predicted quantiles and axis 2 target samples.
        bellman_errors = (target_distribution[:, None, :] -
                          chosen_action_logits[:, :, None]
                          )  # Input `u' of Eq. 9.
        # Huber loss with threshold kappa (Eq. 9 of paper).
        abs_errors = tf.abs(bellman_errors)
        huber_loss = (
            tf.to_float(abs_errors <= self.kappa) * 0.5 * bellman_errors**2 +
            tf.to_float(abs_errors > self.kappa) * self.kappa *
            (abs_errors - 0.5 * self.kappa))

        tau_hat = (
            tf.range(self._num_atoms, dtype=tf.float32) + 0.5
        ) / self._num_atoms  # Quantile midpoints.  See Lemma 2 of paper.

        quantile_huber_loss = (  # Eq. 10 of paper.
            tf.abs(tau_hat[None, :, None] - tf.to_float(bellman_errors < 0)) *
            huber_loss)

        # Sum over tau dimension, average over target value dimension.
        loss = tf.reduce_sum(tf.reduce_mean(quantile_huber_loss, 2), 1)

        if self._replay_scheme == "prioritized":
            target_priorities = self._replay.tf_get_priority(
                self._replay.indices)
            # The original prioritized experience replay uses a linear exponent
            # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
            # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
            # a fixed exponent actually performs better, except on Pong.
            loss_weights = 1.0 / tf.sqrt(target_priorities + 1e-10)
            loss_weights /= tf.reduce_max(loss_weights)

            # Rainbow and prioritized replay are parametrized by an exponent alpha,
            # but in both cases it is set to 0.5 - for simplicity's sake we leave it
            # as is here, using the more direct tf.sqrt(). Taking the square root
            # "makes sense", as we are dealing with a squared loss.
            # Add a small nonzero value to the loss to avoid 0 priority items. While
            # technically this may be okay, setting all items to 0 priority will cause
            # troubles, and also result in 1.0 / 0.0 = NaN correction terms.
            update_priorities_op = self._replay.tf_set_priority(
                self._replay.indices, tf.sqrt(loss + 1e-10))

            # Weight loss by inverse priorities.
            loss = loss_weights * loss
        else:
            update_priorities_op = tf.no_op()

        ### Add the CQL  loss
        # Push down log-sum-exp of Q over all actions and push up Q of the
        # actions actually present in the replay batch.
        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.0,
                                           0.0,
                                           name="action_one_hot")
        # `axis` replaces the deprecated `reduction_indices` argument.
        replay_chosen_q = tf.reduce_sum(
            self._replay_net_outputs.q_values * replay_action_one_hot,
            axis=1,
            name="replay_chosen_q",
        )
        dataset_expec = tf.reduce_mean(replay_chosen_q)
        negative_sampling = tf.reduce_mean(
            tf.reduce_logsumexp(self._replay_net_outputs.q_values, 1))

        min_q_loss = negative_sampling - dataset_expec

        print("MIN Q WEIGHT: ", self.minq_weight)

        # Priority updates must run before the optimizer step is taken.
        with tf.control_dependencies([update_priorities_op]):
            if self.summary_writer is not None:
                with tf.variable_scope("Losses"):
                    tf.summary.scalar("QuantileLoss", tf.reduce_mean(loss))
                    tf.summary.scalar("minQLoss", tf.reduce_mean(min_q_loss))
                    tf.summary.scalar("Q_predictions",
                                      tf.reduce_mean(replay_chosen_q))

            min_q_loss = min_q_loss * self.minq_weight
            return self.optimizer.minimize(tf.reduce_mean(loss) +
                                           min_q_loss), loss