def _get_mdn_coef(output):
    """Split a mixture-density-network output into its parameter groups.

    The last axis of ``output`` is cut into three equal pieces: mixture
    logits, component means and component log standard deviations. The
    mixture logits are re-normalised in log space so that
    ``logsumexp(logmix, axis=-1) == 0``.

    Args:
      output: Tensor whose last dimension is divisible by 3.

    Returns:
      Tuple ``(logmix, mean, logstd)`` where ``logmix`` holds log mixture
      weights.
    """
    log_weights, mu, log_sigma = tf.split(output, num_or_size_splits=3, axis=-1)
    # Subtracting the log-partition turns raw logits into log-probabilities.
    log_partition = tf.reduce_logsumexp(log_weights, axis=-1, keepdims=True)
    return log_weights - log_partition, mu, log_sigma
def softmax_loss(self, antecedent_scores, antecedent_labels):
    """Negative marginal log-likelihood of the gold antecedents.

    Non-gold entries are masked out by adding ``log(float(label))``, which is
    0 for a label of 1 and -inf for a label of 0 (labels are presumably
    boolean -- confirm with caller). The loss per row is
    ``logsumexp(all scores) - logsumexp(gold scores)``.

    Args:
      antecedent_scores: Float tensor of shape [k, max_ant + 1].
      antecedent_labels: Tensor of the same shape marking gold antecedents.

    Returns:
      Float tensor of shape [k].
    """
    log_label_mask = tf.log(tf.to_float(antecedent_labels))
    gold_scores = antecedent_scores + log_label_mask  # [k, max_ant + 1]
    log_norm = tf.reduce_logsumexp(antecedent_scores, [1])  # [k]
    log_gold_mass = tf.reduce_logsumexp(gold_scores, [1])  # [k]
    return log_norm - log_gold_mass  # [k]
def __init__(self,
             env_spec,
             expert_trajs=None,
             discrim_arch=feedforward_energy,
             discrim_arch_args=None,
             l2_reg=0,
             discount=1.0,
             init_itrs=None,
             score_dtau=False,
             state_only=False,
             name='trajprior'):
    """Build the trajectory-level GAN-GCL discriminator graph.

    The discriminator is ``D(tau) = p(tau) / (p(tau) + q(tau))`` where
    ``log p(tau)`` is the (optionally discounted) sum over time of negative
    energies and ``log q(tau)`` is the sum of the per-step policy
    log-probabilities fed through ``self.traj_logprobs``.

    Args:
      env_spec: Environment spec; its observation/action ``flat_dim`` set the
        placeholder widths.
      expert_trajs: Demonstrations, forwarded to ``self.set_demos``.
      discrim_arch: Callable that builds the energy network from the
        observation (or observation+action) tensor.
      discrim_arch_args: Extra keyword arguments for ``discrim_arch``.
        ``None`` (the default) means no extra arguments.
      l2_reg: Weight of an L2 penalty on the discriminator variables; the
        penalty is only built when this is > 0.
      discount: Discount used when summing energies over time; values
        >= 1.0 use a plain undiscounted sum.
      init_itrs: Accepted for interface compatibility; unused here.
      score_dtau: Stored on ``self.score_dtau``.
      state_only: If True the energy sees only observations, otherwise the
        concatenation of observations and actions along the feature axis.
      name: Variable scope for every op and variable built here.
    """
    super(GAN_GCL, self).__init__()
    # A literal ``{}`` default is created once and shared by every call;
    # build a fresh dict per instance instead.
    if discrim_arch_args is None:
        discrim_arch_args = {}

    self.dO = env_spec.observation_space.flat_dim
    self.dU = env_spec.action_space.flat_dim
    self.score_dtau = score_dtau
    self.set_demos(expert_trajs)

    # Build the energy model.
    with tf.variable_scope(name) as vs:
        # Should be batch_size x T x dO/dU
        self.obs_t = tf.placeholder(tf.float32, [None, None, self.dO], name='obs')
        self.act_t = tf.placeholder(tf.float32, [None, None, self.dU], name='act')
        self.traj_logprobs = tf.placeholder(tf.float32, [None, None], name='traj_probs')
        self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
        self.lr = tf.placeholder(tf.float32, (), name='lr')

        if state_only:
            obs_act = self.obs_t
        else:
            obs_act = tf.concat([self.obs_t, self.act_t], axis=2)

        with tf.variable_scope('discrim') as vs2:
            self.energy = discrim_arch(obs_act, **discrim_arch_args)
            discrim_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs2.name)

        self.energy_timestep = self.energy
        # No separate log Z is trained because it cannot be fully separated
        # from the energy function.
        if discount >= 1.0:
            log_p_tau = tf.reduce_sum(-self.energy, axis=1)
        else:
            log_p_tau = discounted_reduce_sum(-self.energy, discount=discount, axis=1)
        # NOTE(review): assumes discrim_arch returns [batch, T, 1] so that
        # log_p_tau is [batch, 1] like log_q_tau -- confirm.
        log_q_tau = tf.reduce_sum(self.traj_logprobs, axis=1, keepdims=True)

        # log(p + q), computed stably in log space.
        log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
        # D(tau) = p / (p + q)
        self.d_tau = tf.exp(log_p_tau - log_pq)
        # labels == 1 scores log D; labels == 0 scores log(q / (p + q)) = log(1 - D).
        cent_loss = -tf.reduce_mean(
            self.labels * (log_p_tau - log_pq)
            + (1 - self.labels) * (log_q_tau - log_pq))

        if l2_reg > 0:
            reg_loss = l2_reg * tf.reduce_sum(
                [tf.reduce_sum(tf.square(var)) for var in discrim_vars])
        else:
            reg_loss = 0

        self.loss = cent_loss + reg_loss
        self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(
            self.loss)
        self._make_param_ops(vs)
def _build(self, x, presence=None):
    """Score input points under a mixture of per-capsule votes.

    Each capsule's vote for each input point is a mixture component (density
    from ``self._get_pdf``). An extra "dummy" component with fixed log-density
    and fixed logit ``-2 * log(10)`` (= log 0.01) is appended along the
    capsule axis. The method also picks a hard winner and a soft
    (posterior-weighted) winner per point.

    Args:
      x: Input points; per the comment below, [B, n_input_points, n_input_dims].
      presence: Optional per-point mask; if given it is cast to float and
        multiplies the per-point log-likelihood. Presumably
        [B, n_input_points] -- confirm with caller.

    Returns:
      ``self.OutputTuple`` with the batch-mean log-likelihood, hard/soft
      winners, their presences, posterior mixing probabilities (dummy dropped,
      transposed to [B, n_input_points, n_caps]), and the mixing logits /
      log-probs (dummy included).
    """
    # x is [B, n_input_points, n_input_dims]
    batch_size, n_input_points = x.shape[:2].as_list()

    # votes and scale have shape [B, n_caps, n_input_points, n_input_dims|1]
    # since scale is a per-caps scalar and we have one vote per capsule
    vote_component_pdf = self._get_pdf(self._votes,
                                       tf.expand_dims(self._scales, -1))

    # expand along caps dimensions -> [B, 1, n_input_points, n_input_dims]
    expanded_x = tf.expand_dims(x, 1)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # [B, n_caps, n_input_points]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    # Constant log-density log(0.01) for the dummy component.
    dummy_vote_log_prob = tf.zeros([batch_size, 1, n_input_points])
    dummy_vote_log_prob -= 2. * tf.log(10.)

    # [B, n_caps + 1, n_input_points]
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 1)

    # [B, n_caps, n_input_points]
    mixing_logits = math_ops.safe_log(self._vote_presence_prob)

    # Constant logit log(0.01) for the dummy component, tiled over points.
    dummy_logit = tf.zeros([batch_size, 1, 1]) - 2. * tf.log(10.)
    dummy_logit = snt.TileByDim([2], [n_input_points])(dummy_logit)

    # [B, n_caps + 1, n_input_points]
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)
    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)
    # [B, n_input_points]
    # NOTE(review): this uses the unnormalised mixing_logits rather than
    # mixing_log_prob (which is only returned); the order-invariant variant
    # of this likelihood uses the normalised form. Confirm which is intended.
    mixture_log_prob_per_point = tf.reduce_logsumexp(
        mixing_logits + vote_log_prob, 1)

    if presence is not None:
        presence = tf.to_float(presence)
        mixture_log_prob_per_point *= presence

    # [B,]
    mixture_log_prob_per_example\
        = tf.reduce_sum(mixture_log_prob_per_point, 1)

    # []
    mixture_log_prob_per_batch = tf.reduce_mean(
        mixture_log_prob_per_example)

    # [B, n_caps + 1, n_input_points]
    posterior_mixing_logits_per_point = mixing_logits + vote_log_prob

    # [B, n_input_points]; the dummy component is excluded from the argmax.
    winning_vote_idx = tf.argmax(
        posterior_mixing_logits_per_point[:, :-1], 1)

    # Build (batch, vote, point) index triples for gather_nd.
    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), 1)
    batch_idx = snt.TileByDim([1], [n_input_points])(batch_idx)

    point_idx = tf.expand_dims(tf.range(n_input_points, dtype=tf.int64), 0)
    point_idx = snt.TileByDim([0], [batch_size])(point_idx)

    idx = tf.stack([batch_idx, winning_vote_idx, point_idx], -1)
    winning_vote = tf.gather_nd(self._votes, idx)
    winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
    # True where a vote's mixing logit beats the dummy's logit.
    vote_presence = tf.greater(mixing_logits[:, :-1],
                               mixing_logits[:, -1:])

    # Capsule that cast the winning vote. NOTE(review): presumably votes are
    # grouped capsule-major in blocks of self._n_votes. The original comment
    # ("the first four votes belong to the square") looks dataset-specific.
    is_from_capsule = winning_vote_idx // self._n_votes

    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, 1)

    # Learned dummy vote with the shape of one capsule's vote, tiled over batch.
    dummy_vote = tf.get_variable('dummy_vote', shape=self._votes[:1, :1].shape)
    dummy_vote = snt.TileByDim([0], [batch_size])(dummy_vote)
    dummy_pres = tf.zeros([batch_size, 1, n_input_points])

    votes = tf.concat((self._votes, dummy_vote), 1)
    pres = tf.concat([self._vote_presence_prob, dummy_pres], 1)

    # Posterior-weighted average over all components (dummy included).
    soft_winner = tf.reduce_sum(
        tf.expand_dims(posterior_mixing_probs, -1) * votes, 1)
    soft_winner_pres = tf.reduce_sum(
        posterior_mixing_probs * pres, 1)

    # Drop the dummy and move points before capsules.
    posterior_mixing_probs = tf.transpose(
        posterior_mixing_probs[:, :-1], (0, 2, 1))

    assert winning_vote.shape == x.shape

    return self.OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.to_float(vote_presence),
        winner=winning_vote,
        winner_pres=winning_pres,
        soft_winner=soft_winner,
        soft_winner_pres=soft_winner_pres,
        posterior_mixing_probs=posterior_mixing_probs,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
    )
def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an estimator.

    Trains a VAE by maximising the ELBO (minimising ``-elbo``). It also logs
    image summaries, an importance-weighted ELBO over ``params["n_samples"]``
    posterior samples, and a cosine-decayed learning rate.

    Args:
      features: The input features for the estimator.
      labels: The labels, unused here.
      mode: Signifies whether it is train or test or predict.
      params: Some hyperparameters as a dictionary. Keys read here:
        "analytic_kl", "mixture_components", "activation", "latent_size",
        "base_depth", "n_samples", "learning_rate", "max_steps".
      config: The RunConfig, unused here.

    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec instance.

    Raises:
      NotImplementedError: if ``analytic_kl`` is requested with a mixture prior.
    """
    del labels, config

    if params["analytic_kl"] and params["mixture_components"] != 1:
        raise NotImplementedError(
            "Using `analytic_kl` is only supported when `mixture_components = 1` "
            "since there's no closed form otherwise.")

    encoder = make_encoder(params["activation"],
                           params["latent_size"],
                           params["base_depth"])
    decoder = make_decoder(params["activation"],
                           params["latent_size"],
                           IMAGE_SHAPE,
                           params["base_depth"])
    latent_prior = make_mixture_prior(params["latent_size"],
                                      params["mixture_components"])

    image_tile_summary(
        "input", tf.cast(features, dtype=tf.float32), rows=1, cols=16)

    approx_posterior = encoder(features)
    # Presumably the sample axis is leading (TFP convention) -- the
    # logsumexp over axis 0 below relies on that.
    approx_posterior_sample = approx_posterior.sample(params["n_samples"])
    decoder_likelihood = decoder(approx_posterior_sample)
    # First 3 posterior samples x first 16 images.
    image_tile_summary(
        "recon/sample",
        tf.cast(decoder_likelihood.sample()[:3, :16], dtype=tf.float32),
        rows=3,
        cols=16)
    image_tile_summary(
        "recon/mean",
        decoder_likelihood.mean()[:3, :16],
        rows=3,
        cols=16)

    # `distortion` is just the negative log likelihood.
    distortion = -decoder_likelihood.log_prob(features)
    avg_distortion = tf.reduce_mean(input_tensor=distortion)
    tf.compat.v1.summary.scalar("distortion", avg_distortion)

    # `rate` is the KL term: analytic when available, otherwise a
    # single-sample Monte Carlo estimate log q(z|x) - log p(z).
    if params["analytic_kl"]:
        rate = tfd.kl_divergence(approx_posterior, latent_prior)
    else:
        rate = (approx_posterior.log_prob(approx_posterior_sample)
                - latent_prior.log_prob(approx_posterior_sample))
    avg_rate = tf.reduce_mean(input_tensor=rate)
    tf.compat.v1.summary.scalar("rate", avg_rate)

    elbo_local = -(rate + distortion)

    elbo = tf.reduce_mean(input_tensor=elbo_local)
    loss = -elbo
    tf.compat.v1.summary.scalar("elbo", elbo)

    # log(mean_k exp(elbo_local_k)) over the sample axis, then batch mean.
    importance_weighted_elbo = tf.reduce_mean(
        input_tensor=tf.reduce_logsumexp(input_tensor=elbo_local, axis=0) -
        tf.math.log(tf.cast(params["n_samples"], dtype=tf.float32)))
    tf.compat.v1.summary.scalar("elbo/importance_weighted",
                                importance_weighted_elbo)

    # Decode samples from the prior for visualization.
    random_image = decoder(latent_prior.sample(16))
    image_tile_summary(
        "random/sample",
        tf.cast(random_image.sample(), dtype=tf.float32),
        rows=4,
        cols=4)
    image_tile_summary("random/mean", random_image.mean(), rows=4, cols=4)

    # Perform variational inference by minimizing the -ELBO.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    learning_rate = tf.compat.v1.train.cosine_decay(
        params["learning_rate"], global_step, params["max_steps"])
    tf.compat.v1.summary.scalar("learning_rate", learning_rate)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
    train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={
            "elbo":
                tf.compat.v1.metrics.mean(elbo),
            "elbo/importance_weighted":
                tf.compat.v1.metrics.mean(importance_weighted_elbo),
            "rate":
                tf.compat.v1.metrics.mean(avg_rate),
            "distortion":
                tf.compat.v1.metrics.mean(avg_distortion),
        },
    )
def __init__(self,
             env,
             policy,
             context_encoder,
             context_encoder_recurrent=False,
             expert_trajs=None,
             reward_arch=relu_net,
             reward_arch_args=None,
             value_fn_arch=relu_net,
             score_discrim=False,
             discount=1.0,
             state_only=True,
             max_path_length=500,
             meta_batch_size=16,
             max_itrs=100,
             fusion=False,
             latent_dim=3,
             imitation_coeff=1.0,
             info_coeff=1.0,
             name='info_airl'):
    """Build the latent-conditioned AIRL discriminator and its training op.

    Graph summary:
      * A latent context m is inferred from expert trajectories by
        ``context_encoder``, sampled with the reparameterisation trick.
      * The reward r(s[, a], m) and a shaping value function V(s, m) define
        ``log p_tau = r + gamma * V(s') - V(s)``; the discriminator output is
        ``exp(log p_tau - logsumexp(log p_tau, log q_tau))`` with
        ``log q_tau`` fed via ``self.lprobs``.
      * Losses: discriminator cross-entropy (``cent_loss``), a mutual
        information term (``info_loss`` / surrogate ``info_surr_loss``), and
        a one-shot imitation likelihood for ``policy`` conditioned on m.
      * A single Adam ``self.step`` applies all four gradient sets.

    Args:
      env: Environment; ``env.spec`` gives observation/action dims and
        ``env.action_space`` must be a ``Box``.
      policy: Policy exposing ``dist_info_sym``, ``distribution`` and
        ``get_params``.
      context_encoder: Encoder exposing ``dist_info_sym`` (returning "mean"
        and "log_std"), ``distribution`` and ``get_params``.
      context_encoder_recurrent: Not referenced in this method.
      expert_trajs: Demonstrations, forwarded to ``self.set_demos``.
      reward_arch: Callable building the reward network.
      reward_arch_args: Extra keyword arguments for ``reward_arch``
        (``None`` means none).
      value_fn_arch: Callable building the value network; must not be None.
      score_discrim: Stored on ``self.score_discrim``.
      discount: Shaping discount, stored as ``self.gamma``.
      state_only: If True the reward sees only observations (+ latent).
      max_path_length: Trajectory length T baked into the placeholders.
      meta_batch_size: Leading (task) dimension of every placeholder.
      max_itrs: Stored on ``self.max_itrs``.
      fusion: If True, ``self.fusion`` is a ``RamFusionDistrCustom`` buffer.
      latent_dim: Size of the latent context m.
      imitation_coeff: Weight of the imitation likelihood loss.
      info_coeff: Weight of the information losses.
      name: Outer variable scope.
    """
    super(InfoAIRL, self).__init__()
    env_spec = env.spec
    if reward_arch_args is None:
        reward_arch_args = {}

    if fusion:
        self.fusion = RamFusionDistrCustom(100, subsample_ratio=0.5)
    else:
        self.fusion = None
    # NOTE(review): the policy input below is obs concatenated with the
    # latent, so the env observation presumably already contains latent_dim
    # extra entries -- confirm against the env wrapper.
    self.dO = env_spec.observation_space.flat_dim - latent_dim
    self.dU = env_spec.action_space.flat_dim
    assert isinstance(env.action_space, Box)
    self.context_encoder = context_encoder
    self.score_discrim = score_discrim
    self.gamma = discount
    assert value_fn_arch is not None
    self.set_demos(expert_trajs)
    self.state_only = state_only
    self.T = max_path_length
    self.max_itrs = max_itrs
    self.latent_dim = latent_dim
    self.meta_batch_size = meta_batch_size
    self.policy = policy

    # build energy model
    with tf.variable_scope(name) as _vs:
        # Should be meta_batch_size x batch_size x T x dO/dU
        self.expert_traj_var = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dO + self.dU],
            name='expert_traj')
        self.sample_traj_var = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dO + self.dU],
            name='sample_traj')
        self.obs_t = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dO], name='obs')
        self.nobs_t = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dO], name='nobs')
        self.act_t = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dU], name='act')
        self.nact_t = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dU], name='nact')
        self.labels = tf.placeholder(
            tf.float32, [meta_batch_size, None, 1, 1], name='labels')
        self.lprobs = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, 1], name='log_probs')
        self.lr = tf.placeholder(tf.float32, (), name='lr')

        self.imitation_expert_obses = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dO])  # None = 1
        self.imitation_expert_acts = tf.placeholder(
            tf.float32, [meta_batch_size, None, self.T, self.dU])
        imitation_expert_obses = tf.reshape(self.imitation_expert_obses,
                                            [-1, self.dO])
        imitation_expert_acts = tf.reshape(self.imitation_expert_acts,
                                           [-1, self.dU])

        with tf.variable_scope('discrim') as dvs:
            # infer m_hat
            # Each expert trajectory is flattened to one row of T*(dO+dU).
            expert_traj_var = tf.reshape(
                self.expert_traj_var, [-1, (self.dO + self.dU) * self.T])
            # m_hat should be of shape meta_batch_size x (batch_size*2) x T x latent_dim
            context_dist_info_vars = self.context_encoder.dist_info_sym(
                expert_traj_var)
            context_mean_var = context_dist_info_vars["mean"]
            context_log_std_var = context_dist_info_vars["log_std"]
            # Reparameterised sample: m = mean + exp(log_std) * eps.
            eps = tf.random.normal(shape=tf.shape(context_mean_var))
            reparam_latent = eps * tf.exp(
                context_log_std_var) + context_mean_var

            # Repeat each trajectory's latent across its T timesteps.
            self.reparam_latent_tile = reparam_latent_tile = tf.tile(
                tf.expand_dims(reparam_latent, axis=1), [1, self.T, 1])

            # One shot imitation
            # Uses the latent of the first trajectory of every task.
            self.imitation_reparam_latent_tile = tf.reshape(
                tf.reshape(
                    self.reparam_latent_tile,
                    [meta_batch_size, -1, self.T, latent_dim])[:, 0, :, :],
                [-1, latent_dim])
            concat_obses_batch = tf.concat([
                imitation_expert_obses, self.imitation_reparam_latent_tile
            ], axis=1)
            policy_dist_info_vars = policy.dist_info_sym(
                obs_var=concat_obses_batch)
            policy_likelihood_loss = -tf.reduce_mean(
                policy.distribution.log_likelihood_sym(
                    imitation_expert_acts, policy_dist_info_vars))

            reparam_latent_tile = tf.reshape(reparam_latent_tile,
                                             [-1, latent_dim])

            rew_input = self.obs_t
            if not self.state_only:
                rew_input = tf.concat([self.obs_t, self.act_t], axis=-1)

            # condition on inferred m
            rew_input = tf.concat([
                tf.reshape(rew_input,
                           [-1, rew_input.get_shape().dims[-1].value]),
                reparam_latent_tile
            ], axis=1)

            with tf.variable_scope('reward'):
                self.reward = reward_arch(rew_input, dout=1, **reward_arch_args)
                # Per-trajectory undiscounted return: [meta_batch_size, batch, 1].
                self.sampled_traj_return = tf.reduce_sum(tf.reshape(
                    self.reward, [meta_batch_size, -1, self.T]),
                                                         axis=-1,
                                                         keepdims=True)
            # with tf.variable_scope('reward', reuse=True):
            #     self.sampled_traj_return = reward_arch(tf.reshape(self.sampled_traj_var, [-1, self.dO+self.dU]), dout=1, **reward_arch)
            #     self.sampled_traj_return = tf.reduce_sum(tf.reshape(self.sampled_traj_return, [meta_batch_size, -1, self.T]), axis=-1)

            #energy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

            npotential_input = tf.concat([
                tf.reshape(self.nobs_t, [-1, self.dO]), reparam_latent_tile
            ], axis=-1)
            potential_input = tf.concat([
                tf.reshape(self.obs_t, [-1, self.dO]), reparam_latent_tile
            ], axis=-1)

            # value function shaping
            # V(s') builds the variables; V(s) reuses them.
            with tf.variable_scope('vfn'):
                fitted_value_fn_n = value_fn_arch(npotential_input, dout=1)
            with tf.variable_scope('vfn', reuse=True):
                self.value_fn = fitted_value_fn = value_fn_arch(
                    potential_input, dout=1)

            # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
            self.qfn = self.reward + self.gamma * fitted_value_fn_n
            log_p_tau = self.reward + self.gamma * fitted_value_fn_n - fitted_value_fn

        log_q_tau = self.lprobs
        log_p_tau = tf.reshape(log_p_tau, [meta_batch_size, -1, self.T, 1])

        log_pq = tf.reduce_logsumexp(
            [log_p_tau, log_q_tau], axis=0)  # [meta_batch_size, -1, self.T, 1]
        # D = p / (p + q), computed in log space for stability.
        self.discrim_output = tf.exp(log_p_tau - log_pq)
        cent_loss = -tf.reduce_mean(self.labels * (log_p_tau - log_pq) +
                                    (1 - self.labels) * (log_q_tau - log_pq))

        # compute mutual information loss
        # sampled_traj_var = tf.reshape(tf.concat([self.obs_t, self.act_t], axis=-1), [-1, (self.dO+self.dU)*self.T])
        log_q_m_tau = tf.reshape(
            self.context_encoder.distribution.log_likelihood_sym(
                reparam_latent, context_dist_info_vars),
            [meta_batch_size, -1, 1])
        # Used for computing gradient w.r.t. psi
        # Mean of log q(m|tau) over entries with label 0, normalised by the
        # fraction of label-0 entries.
        info_loss = -tf.reduce_mean(log_q_m_tau * (1 - tf.squeeze(self.labels, axis=-1))
                                    ) / tf.reduce_mean(1 - self.labels)
        # Used for computing the gradient w.r.t. theta
        # Score-function surrogate: log q(m|tau) weighted by the trajectory
        # return minus its per-task label-0 mean (a baseline).
        info_surr_loss = -tf.reduce_mean(
            (1 - tf.squeeze(self.labels, axis=-1)) * log_q_m_tau * self.sampled_traj_return -
            (1 - tf.squeeze(self.labels, axis=-1)) * log_q_m_tau *
            tf.reduce_mean(self.sampled_traj_return * (1 - tf.squeeze(self.labels, axis=-1)), axis=1, keepdims=True) /
            tf.reduce_mean(1 - self.labels)
        ) / tf.reduce_mean(1 - self.labels)

        self.loss = cent_loss + info_coeff * info_loss
        self.info_loss = info_loss
        self.policy_likelihood_loss = policy_likelihood_loss
        tot_loss = self.loss

        context_encoder_weights = self.context_encoder.get_params(
            trainable=True)
        # reward_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="reward")
        # NOTE(review): substring match -- any trainable variable whose name
        # contains "reward" (or "vfn" below) anywhere in the graph is included.
        reward_weights = [
            i for i in tf.trainable_variables() if "reward" in i.name
        ]
        # value_fn_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="vfn")
        value_fn_weights = [
            i for i in tf.trainable_variables() if "vfn" in i.name
        ]

        # self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
        grads_and_vars_cent = optimizer.compute_gradients(
            cent_loss,
            var_list=reward_weights + value_fn_weights +
            context_encoder_weights)
        grads_and_vars_context = optimizer.compute_gradients(
            info_coeff * info_loss, var_list=context_encoder_weights)
        grads_and_vars_reward = optimizer.compute_gradients(
            info_coeff * info_surr_loss, var_list=reward_weights)
        grads_and_vars_policy = optimizer.compute_gradients(
            imitation_coeff * policy_likelihood_loss,
            var_list=self.policy.get_params(trainable=True) +
            context_encoder_weights)

        # NOTE(review): context_encoder_weights appear in three of these lists
        # (and reward_weights in two), so the same variable gets several
        # (grad, var) pairs in one apply_gradients call rather than one summed
        # gradient -- confirm this is intended.
        self.step = optimizer.apply_gradients(grads_and_vars_cent +
                                              grads_and_vars_context +
                                              grads_and_vars_reward +
                                              grads_and_vars_policy)
        # grads_and_vars_cent = optimizer.compute_gradients(cent_loss, var_list=reward_weights+value_fn_weights)
        # grads_and_vars_reward = optimizer.compute_gradients(info_coeff*info_surr_loss, var_list=reward_weights)
        # self.step = optimizer.apply_gradients(grads_and_vars_cent+grads_and_vars_reward)

        self._make_param_ops(_vs)
def _build(self, x, presence=None):
    """Order-invariant mixture likelihood: every vote can explain every point.

    Each input point is scored under a mixture over *all* votes plus a dummy
    component with fixed log-density and fixed logit ``-2 * log(10)``
    (= log 0.01). Mixing weights come from ``self._vote_presence_prob`` and
    are normalised across votes (dummy included).

    Args:
      x: Input points, [B, n_points, n_input_dims] (the final assertion
        requires the gathered winning votes to match ``x.shape``).
      presence: Optional per-point mask; cast to float and multiplied into
        the per-point log-likelihood. Presumably [B, n_points] -- confirm.

    Returns:
      ``self.OutputTuple`` with the batch-mean log-likelihood, the hard
      winning vote per point and its presence, posterior mixing
      probabilities (dummy dropped), mixing logits / log-probs (dummy
      included). ``soft_winner`` / ``soft_winner_pres`` are zero placeholders.
    """
    batch_size, n_input_points = x.shape[:2].as_list()

    # we don't know what order the initial points came in, so we need to create
    # a big mixture of all votes for every input point
    # [B, 1, n_votes, n_input_dims]
    expanded_votes = tf.expand_dims(self._votes, 1)
    expanded_scale = tf.expand_dims(tf.expand_dims(self._scales, 1), -1)
    vote_component_pdf = self._get_pdf(expanded_votes, expanded_scale)

    # expanded_x: [B, n_points, 1, n_input_dims]; log_prob broadcasts it
    # against the votes to [B, n_points, n_votes, n_input_dims].
    expanded_x = tf.expand_dims(x, 2)
    vote_log_prob_per_dim = vote_component_pdf.log_prob(expanded_x)
    # [B, n_points, n_votes]
    vote_log_prob = tf.reduce_sum(vote_log_prob_per_dim, -1)
    dummy_vote_log_prob = tf.zeros([batch_size, n_input_points, 1])
    dummy_vote_log_prob -= 2. * tf.log(10.)
    # [B, n_points, n_votes + 1]
    vote_log_prob = tf.concat([vote_log_prob, dummy_vote_log_prob], 2)

    # [B, n_votes] (rank 2: it is concatenated with a [B, 1] dummy below)
    mixing_logits = math_ops.safe_log(self._vote_presence_prob)

    dummy_logit = tf.zeros([batch_size, 1]) - 2. * tf.log(10.)
    # [B, n_votes + 1]
    mixing_logits = tf.concat([mixing_logits, dummy_logit], 1)

    mixing_log_prob = mixing_logits - tf.reduce_logsumexp(
        mixing_logits, 1, keepdims=True)

    # [B, 1, n_votes + 1]: the same normalised weights for every point.
    expanded_mixing_logits = tf.expand_dims(mixing_log_prob, 1)
    # [B, n_points]
    mixture_log_prob_per_component\
        = tf.reduce_logsumexp(expanded_mixing_logits + vote_log_prob, 2)

    if presence is not None:
        presence = tf.to_float(presence)
        mixture_log_prob_per_component *= presence

    # [B]
    mixture_log_prob_per_example\
        = tf.reduce_sum(mixture_log_prob_per_component, 1)

    # []
    mixture_log_prob_per_batch = tf.reduce_mean(
        mixture_log_prob_per_example)

    # [B, n_points, n_votes + 1]
    posterior_mixing_logits_per_point = expanded_mixing_logits + vote_log_prob
    # [B, n_points]; the dummy is excluded from the argmax.
    winning_vote_idx = tf.argmax(
        posterior_mixing_logits_per_point[:, :, :-1], 2)

    # Build (batch, vote) index pairs for gather_nd.
    batch_idx = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), -1)
    batch_idx = snt.TileByDim([1], [winning_vote_idx.shape[-1]])(batch_idx)

    idx = tf.stack([batch_idx, winning_vote_idx], -1)
    winning_vote = tf.gather_nd(self._votes, idx)
    winning_pres = tf.gather_nd(self._vote_presence_prob, idx)
    # True where a vote's mixing logit beats the dummy's logit.
    vote_presence = tf.greater(mixing_logits[:, :-1],
                               mixing_logits[:, -1:])

    # Capsule that cast the winning vote. NOTE(review): presumably votes are
    # grouped capsule-major in blocks of self._n_votes. The original comment
    # ("the first four votes belong to the square") looks dataset-specific.
    is_from_capsule = winning_vote_idx // self._n_votes

    # [B, n_points, n_votes] after dropping the dummy column.
    posterior_mixing_probs = tf.nn.softmax(
        posterior_mixing_logits_per_point, -1)[Ellipsis, :-1]

    assert winning_vote.shape == x.shape

    return self.OutputTuple(
        log_prob=mixture_log_prob_per_batch,
        vote_presence=tf.to_float(vote_presence),
        winner=winning_vote,
        winner_pres=winning_pres,
        is_from_capsule=is_from_capsule,
        mixing_logits=mixing_logits,
        mixing_log_prob=mixing_log_prob,
        # TODO(adamrk): this is broken
        soft_winner=tf.zeros_like(winning_vote),
        soft_winner_pres=tf.zeros_like(winning_pres),
        posterior_mixing_probs=posterior_mixing_probs,
    )
def model_fn(features, labels, mode, params):
    """Model function for a retrieve-then-read masked language model.

    A query embedder scores ``num_candidates`` retrieved candidates. A BERT
    reader predicts the masked tokens of each (query, candidate) pair. The
    loss is the masked-token negative log-likelihood, marginalised over
    candidates under the retrieval distribution.

    Args:
      features: Dict with "query_inputs", "candidate_inputs", "joint_inputs"
        (BERT input structs), "mlm_targets", "mlm_positions", "mlm_mask" and,
        outside PREDICT mode, "export_timestamp".
      labels: Unused.
      mode: A ``tf.estimator.ModeKeys`` value.
      params: Dict of hyperparameters; keys read here include the two hub
        module handles, "share_embedders", "num_candidates",
        "num_train_steps", "learning_rate" and "use_tpu".

    Returns:
      A ``TPUEstimatorSpec`` when ``params["use_tpu"]`` is set, otherwise an
      ``EstimatorSpec``.
    """
    del labels

    # ==============================
    # Input features
    # ==============================
    # [batch_size, query_seq_len]
    query_inputs = features["query_inputs"]

    # [batch_size, num_candidates, candidate_seq_len]
    candidate_inputs = features["candidate_inputs"]

    # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
    joint_inputs = features["joint_inputs"]

    # [batch_size, num_masks]
    mlm_targets = features["mlm_targets"]
    mlm_positions = features["mlm_positions"]
    mlm_mask = features["mlm_mask"]

    # ==============================
    # Create modules.
    # ==============================
    # Same graph-variant tags for every module.
    module_tags = {"train"} if mode == tf.estimator.ModeKeys.TRAIN else set()

    bert_module = hub.Module(
        spec=params["bert_hub_module_handle"],
        name="bert",
        tags=module_tags,
        trainable=True)
    hub.register_module_for_export(bert_module, "bert")

    embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags=module_tags,
        trainable=True)
    hub.register_module_for_export(embedder_module, "embedder")

    if params["share_embedders"]:
        query_embedder_module = embedder_module
    else:
        # The scope name is kept as "embedder" so existing checkpoints still
        # load; hub.Module uniquifies the scope name.
        query_embedder_module = hub.Module(
            spec=params["embedder_hub_module_handle"],
            name="embedder",
            tags=module_tags,
            trainable=True)
        # Export the *query* embedder; previously the candidate embedder was
        # registered under this name, so the query module was never exported.
        hub.register_module_for_export(query_embedder_module, "query_embedder")

    # ==============================
    # Retrieve.
    # ==============================
    # [batch_size, projected_size]
    query_emb = query_embedder_module(
        inputs=dict(
            input_ids=query_inputs.token_ids,
            input_mask=query_inputs.mask,
            segment_ids=query_inputs.segment_ids),
        signature="projected")

    # [batch_size * num_candidates, candidate_seq_len]
    flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

    # [batch_size * num_candidates, projected_size]
    flat_candidate_emb = embedder_module(
        inputs=dict(
            input_ids=flat_candidate_inputs.token_ids,
            input_mask=flat_candidate_inputs.mask,
            segment_ids=flat_candidate_inputs.segment_ids),
        signature="projected")

    # [batch_size, num_candidates, projected_size]
    unflattened_candidate_emb = unflatten(flat_candidate_emb)

    # Inner product of each query with each of its candidates.
    # [batch_size, num_candidates]
    retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                                unflattened_candidate_emb)

    # ==============================
    # Read.
    # ==============================
    # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
    flat_joint_inputs, _ = flatten_bert_inputs(joint_inputs)

    # Each candidate shares its query's mask positions.
    # [batch_size * num_candidates, num_masks]
    flat_mlm_positions, _ = tensor_utils.flatten(
        tf.tile(
            tf.expand_dims(mlm_positions, 1),
            [1, params["num_candidates"], 1]))

    batch_size, num_masks = tensor_utils.shape(mlm_targets)

    # Dict of outputs from the "mlm" signature; "mlm_logits" is used below.
    flat_joint_bert_outputs = bert_module(
        inputs=dict(
            input_ids=flat_joint_inputs.token_ids,
            input_mask=flat_joint_inputs.mask,
            segment_ids=flat_joint_inputs.segment_ids,
            mlm_positions=flat_mlm_positions),
        signature="mlm",
        as_dict=True)

    # [batch_size, num_candidates]
    candidate_score = retrieval_score

    # [batch_size, num_candidates]
    candidate_log_probs = tf.math.log_softmax(candidate_score)

    # ==============================
    # Compute marginal log-likelihood.
    # ==============================
    # [batch_size * num_candidates, num_masks, vocab_size]
    flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

    # [batch_size, num_candidates, num_masks, vocab_size]
    mlm_logits = tf.reshape(
        flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
    mlm_log_probs = tf.math.log_softmax(mlm_logits)

    # [batch_size, num_candidates, num_masks]
    tiled_mlm_targets = tf.tile(
        tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

    # [batch_size, num_candidates, num_masks, 1]
    tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

    # [batch_size, num_candidates, num_masks, 1]
    gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

    # [batch_size, num_candidates, num_masks]
    gold_log_probs = tf.squeeze(gold_log_probs, -1)

    # log p(candidate) + log p(token | candidate)
    # [batch_size, num_candidates, num_masks]
    joint_gold_log_probs = (
        tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

    # Marginalise over candidates.
    # [batch_size, num_masks]
    marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

    # [batch_size, num_masks]
    float_mlm_mask = tf.cast(mlm_mask, tf.float32)

    # Mean over real (unmasked) positions; 0 if there are none.
    # []
    loss = -tf.div_no_nan(
        tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
        tf.reduce_sum(float_mlm_mask))

    # ==============================
    # Optimization
    # ==============================
    # 10% of training, clamped to [100, 10000] steps.
    num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=num_warmup_steps,
        use_tpu=params["use_tpu"])

    # ==============================
    # Evaluation
    # ==============================
    eval_metric_ops = None if params["use_tpu"] else {}
    if mode != tf.estimator.ModeKeys.PREDICT:
        # Gain of the marginal over conditioning only on the first candidate.
        # [batch_size, num_masks]
        retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
        retrieval_utility *= float_mlm_mask

        # []
        retrieval_utility = tf.div_no_nan(
            tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
        add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

        # Staleness of the exported retriever; zeroed when no timestamp is set.
        has_timestamp = tf.cast(
            tf.greater(features["export_timestamp"], 0), tf.float64)
        off_policy_delay_secs = (
            tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
        off_policy_delay_mins = off_policy_delay_secs / 60.0
        off_policy_delay_mins *= has_timestamp

        add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                        eval_metric_ops)

    # Create empty predictions to avoid errors when running in prediction mode.
    predictions = {}

    if params["use_tpu"]:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            predictions=predictions)

    if eval_metric_ops is not None:
        # Make sure the eval metrics are updated during training so that we get
        # quick feedback from tensorboard summaries when debugging locally.
        with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
            loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
def iwae(p_z,
         p_x_given_z,
         q_z,
         observations,
         num_samples,
         cvs,
         contexts=None,
         antithetic=False):
    """Computes a gradient of the IWAE estimator.

    Builds several surrogate objectives that share the IWAE bound
    ``log(1/K * sum_k w_k)`` but differ in the gradient they give the
    inference network. The keys are "iwae", "rws", "stl", "dreg", "rws-dreg",
    "dreg-norm", "rws-dreg-norm", "dreg-alpha", "jk", "jk-dreg" and
    "dreg-cv". Reads ``FLAGS.image_summary``, ``FLAGS.dataset`` and
    ``FLAGS.alpha``.

    Args:
      p_z: The prior. Should be a callable that optionally accepts a
        conditioning context and returns a tfp.distributions.Distribution
        which has the log_prob and sample methods implemented. The
        distribution should be over a [batch_size, latent_dim] space.
      p_x_given_z: The likelihood. Should be a callable that accepts as input
        a tensor of shape [num_samples, batch_size, latent_size + context_size]
        and returns a tfd.Distribution over a [num_samples, batch_size,
        data_dim] space.
      q_z: The proposal, should be a callable which accepts a batch of
        observations of shape [batch_size, data_dim] and returns a
        distribution over [batch_size, latent_dim].
      observations: A float Tensor of shape [batch_size, data_dim] containing
        the observations.
      num_samples: The number of samples for the IWAE estimator.
      cvs: Control variate variables, unpacked as (alpha, beta, gamma, delta).
      contexts: A float Tensor of shape [batch_size, context_dim] containing
        the contexts. (Optionally, none)
      antithetic: Whether to use antithetic sampling.

    Returns:
      estimators: Dictionary of tuples (objective, neg_model_loss,
        neg_inference_network_loss).
    """
    alpha, beta, gamma, delta = cvs
    batch_size = tf.shape(observations)[0]
    proposal = q_z(observations, contexts, stop_gradient=False)
    # [num_samples, batch_size, latent_size]

    # If antithetic sampling, draw half of the samples and use the antithetics
    # for the other half.
    # NOTE(review): with an odd num_samples, z ends up with num_samples - 1
    # samples while the context tiling and log(num_samples) below still assume
    # num_samples. The reflection about proposal.loc also assumes a proposal
    # that is symmetric about its loc.
    if antithetic:
        z_pos = proposal.sample(sample_shape=[num_samples // 2])
        z_neg = 2 * proposal.loc - z_pos
        z = tf.concat((z_pos, z_neg), axis=0)
    else:
        z = proposal.sample(sample_shape=[num_samples])

    tiled_contexts = None
    if contexts is not None:
        tiled_contexts = tf.tile(tf.expand_dims(contexts, 0), [num_samples, 1, 1])
    likelihood = p_x_given_z(z, tiled_contexts)

    # Before reduce_sum is [num_samples, batch_size, latent_dim].
    # Sum over the latent dim.
    log_q_z = tf.reduce_sum(proposal.log_prob(z), axis=-1)
    # Before reduce_sum is [num_samples, batch_size, latent_dim].
    # Sum over latent dim.
    prior = p_z(contexts)
    log_p_z = tf.reduce_sum(prior.log_prob(z), axis=-1)
    # Before reduce_sum is [num_samples, batch_size, data_dim]
    log_p_x_given_z = tf.reduce_sum(likelihood.log_prob(observations), axis=-1)

    # log w_k = log p(z_k) + log p(x|z_k) - log q(z_k|x); [num_samples, batch_size]
    log_weights = log_p_z + log_p_x_given_z - log_q_z
    log_sum_weight = tf.reduce_logsumexp(log_weights, axis=0)
    log_avg_weight = log_sum_weight - tf.log(tf.to_float(num_samples))
    # Self-normalised importance weights, treated as constants.
    normalized_weights = tf.stop_gradient(tf.nn.softmax(log_weights, axis=0))

    if FLAGS.image_summary:
        # Show the decoder output for the highest-weight sample of each example.
        best_index = tf.to_int32(tf.argmax(normalized_weights, axis=0))
        indices = tf.stack((best_index, tf.range(0, batch_size)), axis=-1)
        best_images = tf.gather_nd(likelihood.probs_parameter(), indices)

        if FLAGS.dataset == "struct_mnist":
            tf.summary.image("bottom_half", tf.reshape(best_images, [batch_size, -1, 28, 1]))
        else:
            tf.summary.image("output", tf.reshape(best_images, [batch_size, -1, 28, 1]))
        tf.summary.image("input", tf.reshape(observations, [batch_size, -1, 28, 1]))

    # Compute gradient estimators
    model_loss = log_avg_weight
    estimators = {}

    estimators["iwae"] = (log_avg_weight, log_avg_weight, log_avg_weight)

    # log q(z) with z held fixed: gradient flows only through q's parameters.
    stopped_z_log_q_z = tf.reduce_sum(proposal.log_prob(tf.stop_gradient(z)), axis=-1)
    estimators["rws"] = (log_avg_weight, model_loss,
                         tf.reduce_sum(normalized_weights * stopped_z_log_q_z, axis=0))

    # Doubly reparameterized
    # q's parameters are stopped, so gradient flows only through z.
    stopped_proposal = q_z(observations, contexts, stop_gradient=True)
    stopped_log_q_z = tf.reduce_sum(stopped_proposal.log_prob(z), axis=-1)
    stopped_log_weights = log_p_z + log_p_x_given_z - stopped_log_q_z
    sq_normalized_weights = tf.square(normalized_weights)

    estimators["stl"] = (log_avg_weight, model_loss,
                         tf.reduce_sum(normalized_weights * stopped_log_weights, axis=0))
    estimators["dreg"] = (log_avg_weight, model_loss,
                          tf.reduce_sum(sq_normalized_weights * stopped_log_weights, axis=0))
    estimators["rws-dreg"] = (
        log_avg_weight, model_loss,
        tf.reduce_sum(
            (normalized_weights - sq_normalized_weights) * stopped_log_weights,
            axis=0))

    # Add normed versions
    normalized_sq_normalized_weights = (
        sq_normalized_weights / tf.reduce_sum(sq_normalized_weights,
                                              axis=0, keepdims=True))
    estimators["dreg-norm"] = (
        log_avg_weight, model_loss,
        tf.reduce_sum(normalized_sq_normalized_weights * stopped_log_weights,
                      axis=0))

    rws_dregs_weights = normalized_weights - sq_normalized_weights
    normalized_rws_dregs_weights = rws_dregs_weights / tf.reduce_sum(
        rws_dregs_weights, axis=0, keepdims=True)
    estimators["rws-dreg-norm"] = (
        log_avg_weight, model_loss,
        tf.reduce_sum(normalized_rws_dregs_weights * stopped_log_weights,
                      axis=0))

    estimators["dreg-alpha"] = (log_avg_weight, model_loss,
                                (1 - FLAGS.alpha) * estimators["dreg"][-1] +
                                FLAGS.alpha * estimators["rws-dreg"][-1])

    # Jackknife
    # [batch_size, num_samples (j), num_samples (i)] with the diagonal j == i
    # set to -inf, so logsumexp over axis 1 leaves sample i out.
    loo_log_weights = tf.tile(tf.expand_dims(tf.transpose(log_weights), -1),
                              [1, 1, num_samples])
    loo_log_weights = tf.matrix_set_diag(
        loo_log_weights, -np.inf * tf.ones([batch_size, num_samples]))
    loo_log_avg_weight = tf.reduce_mean(
        tf.reduce_logsumexp(loo_log_weights, axis=1) -
        tf.log(tf.to_float(num_samples - 1)), axis=-1)
    # Jackknife bias correction: K * L_K - (K - 1) * mean(L_{K-1}).
    jk_model_loss = num_samples * log_avg_weight - (num_samples - 1) * loo_log_avg_weight

    estimators["jk"] = (jk_model_loss, jk_model_loss, jk_model_loss)

    # Compute JK w/ DReG for the inference network
    loo_normalized_weights = tf.reduce_mean(tf.square(
        tf.stop_gradient(tf.nn.softmax(loo_log_weights, axis=1))), axis=-1)
    estimators["jk-dreg"] = (
        jk_model_loss, jk_model_loss,
        num_samples * tf.reduce_sum(sq_normalized_weights * stopped_log_weights,
                                    axis=0)
        - (num_samples - 1) * tf.reduce_sum(
            tf.transpose(loo_normalized_weights) * stopped_log_weights, axis=0)
    )

    # Compute control variates
    # Leave-one-out log sum of weights, back to [num_samples, batch_size].
    loo_baseline = tf.expand_dims(tf.transpose(log_weights), -1)
    loo_baseline = tf.tile(loo_baseline, [1, 1, num_samples])
    loo_baseline = tf.matrix_set_diag(
        loo_baseline, -np.inf * tf.ones_like(tf.transpose(log_weights)))
    loo_baseline = tf.reduce_logsumexp(loo_baseline, axis=1)
    loo_baseline = tf.transpose(loo_baseline)

    learning_signal = tf.stop_gradient(tf.expand_dims(
        log_avg_weight, 0)) - (1 - gamma) * tf.stop_gradient(loo_baseline)
    vimco = tf.reduce_sum(learning_signal * stopped_z_log_q_z, axis=0)

    first_part = alpha * vimco + (1 - alpha) * tf.reduce_sum(
        normalized_weights * stopped_log_weights, axis=0)
    second_part = ((1 - beta) * (tf.reduce_sum(
        ((1 - delta) / tf.to_float(num_samples) - normalized_weights) *
        stopped_z_log_q_z, axis=0)) + beta * tf.reduce_sum(
            (sq_normalized_weights - normalized_weights) * stopped_log_weights,
            axis=0))
    estimators["dreg-cv"] = (log_avg_weight, model_loss, first_part + second_part)

    return estimators
def _logmatmulexp(a, b):
    """Multiply two matrices of log-values, staying in log space.

    Entry ``[..., i, j]`` of the result is
    ``log(sum_k exp(a[..., i, k] + b[..., k, j]))``, i.e. the log of
    ``exp(a) @ exp(b)``, evaluated with ``tf.reduce_logsumexp`` so large
    magnitudes do not overflow. Leading (batch) axes combine by ordinary
    broadcasting of ``a`` and ``b``.

    Args:
      a: Tensor whose last two axes are ``[I, K]``.
      b: Tensor whose last two axes are ``[K, J]``.

    Returns:
      Tensor whose last two axes are ``[I, J]``.
    """
    lhs = a[..., :, :, None]  # [..., I, K, 1]
    rhs = b[..., None, :, :]  # [..., 1, K, J]
    pairwise = lhs + rhs      # [..., I, K, J]
    # Sum out the shared K axis (second from last).
    return tf.reduce_logsumexp(pairwise, axis=-2)
def marginal_log_loss(logits, is_correct):
    """Loss based on the negative marginal log-likelihood.

    Returns ``logsumexp(logits) - logsumexp(logits restricted to correct
    entries)`` along the last axis. The restriction is applied by adding
    ``mask_to_score(is_correct)`` to the logits (presumably a 0 / large-negative
    additive mask — confirm against ``mask_to_score``).

    Args:
      logits: Tensor of unnormalized scores; the last axis indexes candidates.
      is_correct: Mask marking the correct candidates; assumed broadcastable
        to ``logits`` (TODO confirm).

    Returns:
      Tensor equal to ``logits`` with its last axis reduced away.
    """
    masked_logits = logits + mask_to_score(is_correct)
    log_correct_mass = tf.reduce_logsumexp(masked_logits, -1)
    log_partition = tf.reduce_logsumexp(logits, -1)
    return log_partition - log_correct_mass
def log_prob(self, x):
    """Log-density of ``x`` under the mixture.

    ``x`` gains a new axis at position 1, presumably so that
    ``self._component_log_prob`` broadcasts it against each mixture component
    (confirm). The per-component log-probabilities are offset by
    ``self.mixing_log_prob`` and combined with a log-sum-exp over axis 1.
    """
    per_component = self._component_log_prob(tf.expand_dims(x, 1))
    weighted = per_component + self.mixing_log_prob
    return tf.reduce_logsumexp(weighted, 1)
def mixing_log_prob(self):
    """Log mixing weights: ``self._mixing_logits`` normalized along axis 1.

    Computed as ``logits - logsumexp(logits, axis=1)``. The normalizer keeps
    its reduced axis so it broadcasts back over the logits.

    NOTE(review): ``log_prob`` reads ``self.mixing_log_prob`` without calling
    it, so this is presumably exposed via ``@property`` — confirm.
    """
    logits = self._mixing_logits
    log_normalizer = tf.reduce_logsumexp(logits, 1, keepdims=True)
    return logits - log_normalizer
def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an estimator.

    Encodes ``features`` into an approximate posterior and decodes
    ``FLAGS.n_samples`` posterior samples. It then scores the reconstruction
    residual with ``adaptive.AdaptiveImageLossFunction`` and optimizes either
    the negative ELBO or, when ``FLAGS.bilbo`` is set, the negative BILBO.
    Image and scalar summaries are emitted for visualization.

    Arguments:
      features: The input features for the estimator.
      labels: The labels, unused here.
      mode: Signifies whether it is train or test or predict.
      params: Some parameters, unused here.
      config: The RunConfig, unused here.

    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec instance.

    Raises:
      NotImplementedError: If the combination of FLAGS is unsupported
        (``analytic_kl``, ``floating_prior`` or ``fitted_samples`` with
        ``mixture_components != 1``; ``floating_prior`` without
        ``unit_posterior``; ``bilbo`` without ``floating_prior``).
    """
    del labels, params, config

    if FLAGS.analytic_kl and FLAGS.mixture_components != 1:
        raise NotImplementedError(
            "Using `analytic_kl` is only supported when `mixture_components = 1` "
            "since there's no closed form otherwise.")
    if FLAGS.floating_prior and not (FLAGS.unit_posterior and
                                     FLAGS.mixture_components == 1):
        raise NotImplementedError(
            "Using `floating_prior` is only supported when `unit_posterior` = True "
            "since there's a scale ambiguity otherwise, and when "
            "`mixture_components = 1` since there's no closed form otherwise.")
    if FLAGS.fitted_samples and FLAGS.mixture_components != 1:
        raise NotImplementedError(
            "Using `fitted_samples` is only supported when "
            "`mixture_components = 1` since there's no closed form otherwise.")
    # `bilbo` reads `posterior_batch_mean`/`posterior_batch_variance`, which
    # are only built below when `floating_prior` (or `fitted_samples`) is set.
    if FLAGS.bilbo and not FLAGS.floating_prior:
        raise NotImplementedError(
            "Using `bilbo` is only supported when `floating_prior = True`.")

    activation = tf.nn.leaky_relu
    encoder = make_encoder(activation, FLAGS.latent_size, FLAGS.base_depth)
    decoder = make_decoder(activation, FLAGS.latent_size,
                           [IMAGE_SIZE] * 2 + [3], FLAGS.base_depth)

    approx_posterior = encoder(features)
    approx_posterior_sample = approx_posterior.sample(FLAGS.n_samples)
    decoder_mu = decoder(approx_posterior_sample)

    if FLAGS.floating_prior or FLAGS.fitted_samples:
        # Batch average of the squared posterior means (a second moment)
        # plus the batch-averaged posterior variance.
        posterior_batch_mean = tf.reduce_mean(approx_posterior.mean()**2, [0])
        posterior_batch_variance = tf.reduce_mean(
            approx_posterior.stddev()**2, [0])
        posterior_scale = posterior_batch_mean + posterior_batch_variance
        floating_prior = tfd.MultivariateNormalDiag(
            tf.zeros(FLAGS.latent_size), tf.sqrt(posterior_scale))
        tf.summary.scalar("posterior_scale", tf.reduce_sum(posterior_scale))

    if FLAGS.floating_prior:
        latent_prior = floating_prior
    else:
        latent_prior = make_mixture_prior(FLAGS.latent_size,
                                          FLAGS.mixture_components)

    # Decode samples from the prior for visualization.
    if FLAGS.fitted_samples:
        sample_distribution = floating_prior
    else:
        sample_distribution = latent_prior

    n_samples = VIZ_GRID_SIZE**2
    random_mu = decoder(sample_distribution.sample(n_samples))

    residual = tf.reshape(features - decoder_mu, [-1] + [IMAGE_SIZE] * 2 + [3])

    # Both loss variants share these arguments; the `alpha_*` settings are
    # only passed when the Student's t variant is not selected.
    lossfun_kwargs = dict(
        color_space=FLAGS.color_space,
        representation=FLAGS.representation,
        wavelet_num_levels=FLAGS.wavelet_num_levels,
        wavelet_scale_base=FLAGS.wavelet_scale_base,
        use_students_t=FLAGS.use_students_t,
        scale_lo=FLAGS.scale_lo,
        scale_init=FLAGS.scale_init)
    if not FLAGS.use_students_t:
        lossfun_kwargs.update(
            alpha_lo=FLAGS.alpha_lo,
            alpha_hi=FLAGS.alpha_hi,
            alpha_init=FLAGS.alpha_init)
    lossfun = adaptive.AdaptiveImageLossFunction(
        residual.shape[1:], residual.dtype, **lossfun_kwargs)

    nll = lossfun(residual)
    # Restore the leading two axes of `decoder_mu` (presumably
    # [n_samples, batch] — confirm) on top of the per-pixel image axes.
    nll = tf.reshape(nll, [tf.shape(decoder_mu)[0],
                           tf.shape(decoder_mu)[1]] + [IMAGE_SIZE] * 2 + [3])

    # Clipping to prevent the loss from nanning out.
    max_val = np.finfo(np.float32).max
    nll = tf.clip_by_value(nll, -max_val, max_val)

    viz_n_inputs = np.int32(np.minimum(VIZ_MAX_N_INPUTS, FLAGS.batch_size))
    viz_n_samples = np.int32(np.minimum(VIZ_MAX_N_SAMPLES, FLAGS.n_samples))

    image_tile_summary("input", tf.to_float(features), rows=1, cols=viz_n_inputs)
    image_tile_summary(
        "recon/mean",
        decoder_mu[:viz_n_samples, :viz_n_inputs],
        rows=viz_n_samples,
        cols=viz_n_inputs)
    img_summary_input = image_tile_summary(
        "input1", tf.to_float(features), rows=viz_n_inputs, cols=1)
    img_summary_recon = image_tile_summary(
        "recon1", decoder_mu[:1, :viz_n_inputs], rows=viz_n_inputs, cols=1)
    image_tile_summary(
        "random/mean", random_mu, rows=VIZ_GRID_SIZE, cols=VIZ_GRID_SIZE)

    # Sum the per-pixel NLL over the three image axes.
    distortion = tf.reduce_sum(nll, axis=[2, 3, 4])
    avg_distortion = tf.reduce_mean(distortion)
    tf.summary.scalar("distortion", avg_distortion)

    if FLAGS.analytic_kl:
        rate = tfd.kl_divergence(approx_posterior, latent_prior)
    else:
        rate = (approx_posterior.log_prob(approx_posterior_sample) -
                latent_prior.log_prob(approx_posterior_sample))
    avg_rate = tf.reduce_mean(rate)
    tf.summary.scalar("rate", avg_rate)

    elbo_local = -(rate + distortion)
    elbo = tf.reduce_mean(elbo_local)
    tf.summary.scalar("elbo", elbo)

    if FLAGS.bilbo:
        bilbo = -0.5 * tf.reduce_sum(
            tf.log1p(posterior_batch_mean / posterior_batch_variance)
        ) - avg_distortion
        tf.summary.scalar("bilbo", bilbo)
        loss = -bilbo
    else:
        loss = -elbo

    # log-mean-exp of the per-sample ELBOs over axis 0.
    importance_weighted_elbo = tf.reduce_mean(
        tf.reduce_logsumexp(elbo_local, axis=0) -
        tf.math.log(tf.to_float(FLAGS.n_samples)))
    tf.summary.scalar("elbo/importance_weighted", importance_weighted_elbo)

    # Perform variational inference by minimizing the -ELBO.
    global_step = tf.train.get_or_create_global_step()
    decay_start_step = int(FLAGS.decay_start * FLAGS.max_steps)
    # cosine_decay divides by `decay_steps`; a zero-length decay window
    # (e.g. decay_start = 1) would yield 0/0 = NaN, so keep it at least 1.
    decay_steps = max(1, int((1. - FLAGS.decay_start) * FLAGS.max_steps))
    learning_rate = tf.train.cosine_decay(
        FLAGS.learning_rate,
        tf.maximum(tf.cast(0, tf.int64), global_step - decay_start_step),
        decay_steps)
    tf.summary.scalar("learning_rate", learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate)

    if mode == tf_estimator.ModeKeys.TRAIN:
        train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        train_op = None

    eval_metric_ops = {}
    eval_metric_ops["elbo"] = tf.metrics.mean(elbo)
    eval_metric_ops["elbo/importance_weighted"] = tf.metrics.mean(
        importance_weighted_elbo)
    eval_metric_ops["rate"] = tf.metrics.mean(avg_rate)
    eval_metric_ops["distortion"] = tf.metrics.mean(avg_distortion)
    # This ugly hackery is necessary to get TF to visualize when running the
    # eval set, apparently.
    eval_metric_ops["img_summary_input"] = (img_summary_input, tf.no_op())
    eval_metric_ops["img_summary_recon"] = (img_summary_recon, tf.no_op())
    eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
    )
def compute_iw_marginal(self, targets, targets_mask, decoder_self_attention_bias,
                        features, n_samples, reduce_mean=True, **kwargs):
    """Importance-weighted estimate of the length-normalized log marginal.

    Draws ``n_samples`` latents with ``self.sample_q`` and computes three terms
    for each one:

    * ``log q(z)``, reduced with ``gops.reduce_mean_over_l_sum_over_c``;
    * ``log p(z)``, the prior base log-prob plus the log|det| term, divided by
      the number of unmasked positions;
    * ``log p(x|z)``, the negated ``loss_iw`` numerator over its denominator.

    These are combined as ``logsumexp_k(log p(x|z_k) + log_softmax_k(log p(z_k)
    - log q(z_k)))``, i.e. with self-normalized importance weights.

    Args:
      targets: Targets passed to ``self.sample_q``.
      targets_mask: Padding mask for ``targets``; tiled per sample.
      decoder_self_attention_bias: Attention bias passed to ``self.sample_q``.
      features: Feature dict; shallow-copied, with ``"targets"`` tiled per sample.
      n_samples: Number of importance samples K.
      reduce_mean: If True, also average over the batch.
      **kwargs: Extra inputs forwarded to ``sample_q`` and, once tiled, to the
        prior and decoder.

    Returns:
      float32 Tensor: a scalar if ``reduce_mean``, otherwise shape [B].
    """
    hparams = self._hparams
    z_q, log_q_z, _ = self.sample_q(
        targets, targets_mask, decoder_self_attention_bias,
        n_samples=n_samples, temp=1.0, **kwargs)  # [K*B, L, C]

    # Tile every auxiliary input so it lines up with the K*B samples.
    iw_kwargs = {
        name: ops.prepare_for_iw(tensor, n_samples)
        for name, tensor in kwargs.items()
    }
    iw_targets_mask = ops.prepare_for_iw(targets_mask, n_samples)
    iw_decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - iw_targets_mask))
    iw_features = copy.copy(features)
    iw_features["targets"] = ops.prepare_for_iw(features["targets"], n_samples)

    # log p(z): base log-density plus the log|det| correction.
    log_p_z_base, log_abs_det = self.compute_prior_log_prob(
        z_q, iw_targets_mask, iw_decoder_self_attention_bias,
        check_invertibility=False, **iw_kwargs)
    log_p_z = log_p_z_base + log_abs_det

    # log p(x|z): negated loss numerator over its denominator.
    body_output = ops.decoder(
        "decoder", z_q, hparams, iw_decoder_self_attention_bias, **iw_kwargs)
    logits = self.top(body_output, iw_features)
    numerator, denominator = self.loss_iw(logits, iw_features)
    numerator = tf.reduce_sum(numerator[..., 0, 0], 1)  # [K*B]
    denominator = tf.reduce_sum(denominator[..., 0, 0], 1)  # [K*B]
    log_p_x = -1 * numerator / denominator

    log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, iw_targets_mask)
    log_p_z = log_p_z / tf.reduce_sum(iw_targets_mask, 1)

    # Split the flattened K*B axis back into [K, B].
    log_p_x, log_q_z, log_p_z = (
        ops.unprepare_for_iw(term, n_samples)
        for term in (log_p_x, log_q_z, log_p_z))

    # Self-normalized importance weights over the K samples.
    log_weights = tf.nn.log_softmax(log_p_z - log_q_z, axis=0)  # [K, B]
    iw_marginal = tf.reduce_logsumexp(log_p_x + log_weights, 0)  # [B]

    if reduce_mean:
        iw_marginal = tf.reduce_mean(iw_marginal, 0)  # []
    return tf.cast(iw_marginal, tf.float32)
def model_uncertainty(logits):
    """Mutual information between the categorical label and the model parameters.

    A way to evaluate uncertainty in ensemble models is to measure its spread or
    `disagreement`. One way is to measure the mutual information between the
    categorical label and the parameters of the categorical output. This
    assesses uncertainty in predictions due to `model uncertainty`. Model
    uncertainty can be expressed as the difference of the total uncertainty and
    the expected data uncertainty:
    `Model uncertainty = Total uncertainty - Expected data uncertainty`, where

    * `Total uncertainty`: Entropy of expected predictive distribution.
    * `Expected data uncertainty`: Expected entropy of individual predictive
      distribution.

    This formulation was given by [1, 2] and allows the decomposition of total
    uncertainty into model uncertainty and expected data uncertainty. The
    total uncertainty will be high whenever the model is uncertain. However, the
    model uncertainty, the difference between total and expected data
    uncertainty, will be non-zero iff the ensemble disagrees.

    ## References:
    [1] Depeweg, S., Hernandez-Lobato, J. M., Doshi-Velez, F, and Udluft, S.
        Decomposition of uncertainty for active learning and reliable
        reinforcement learning in stochastic systems.
        stat 1050, p.11, 2017.
    [2] Malinin, A., Mlodozeniec, B., and Gales, M.
        Ensemble Distribution Distillation.
        arXiv:1905.00076, 2019.

    Args:
      logits: Tensor (or value convertible to one), shape (N, k, nc). Logits
        for N instances, k ensembles and nc classes. Any floating dtype.

    Raises:
      TypeError: Raised if `logits` is None.
      ValueError: Raised if `logits` has a statically known rank other than 3.

    Returns:
      model_uncertainty: Tensor, shape (N,).
      total_uncertainty: Tensor, shape (N,).
      expected_data_uncertainty: Tensor, shape (N,).
    """
    if logits is None:
        raise TypeError("model_uncertainty expected logits to be set.")
    logits = tf.convert_to_tensor(logits)
    # Static rank check: unlike `tf.rank(logits).numpy()` this also works in
    # graph mode / inside tf.function, where `.numpy()` is unavailable.
    rank = logits.shape.rank
    if rank is not None and rank != 3:
        raise ValueError(
            "model_uncertainty expected logits to be of shape (N, k, nc), "
            "instead got {}".format(logits.shape))

    # expected data uncertainty
    log_prob = tf.math.log_softmax(logits, -1)
    prob = tf.exp(log_prob)
    expected_data_uncertainty = tf.reduce_mean(
        tf.reduce_sum(-prob * log_prob, -1), -1)

    # Total uncertainty: entropy of the ensemble-mean predictive distribution,
    # averaged in log space as logsumexp over members minus log(k). The
    # ensemble size is read dynamically and cast to the logits' own dtype so
    # unknown k and non-float32 inputs work.
    n_ens = tf.cast(tf.shape(log_prob)[1], log_prob.dtype)
    log_expected_probabilities = (
        tf.reduce_logsumexp(log_prob, 1) - tf.math.log(n_ens))
    expected_probabilities = tf.exp(log_expected_probabilities)
    total_uncertainty = tf.reduce_sum(
        -expected_probabilities * log_expected_probabilities, -1)

    model_uncertainty_ = total_uncertainty - expected_data_uncertainty
    return model_uncertainty_, total_uncertainty, expected_data_uncertainty
def _log_prob_from_logits(logits, axis=2):
    """Normalize ``logits`` into log-probabilities along ``axis``.

    Computes ``logits - logsumexp(logits, axis)`` (a log-softmax). The reduced
    axis is kept so the normalizer broadcasts back over ``logits``.

    Args:
      logits: Tensor of unnormalized log-scores.
      axis: Axis to normalize over. Defaults to 2, the axis this helper has
        always used.

    Returns:
      Tensor with the same shape as ``logits``.
    """
    # `keepdims` replaces the deprecated `keep_dims` spelling (gone in TF 2).
    return logits - tf.reduce_logsumexp(logits, axis=axis, keepdims=True)
def _build_train_op(self):
    """Builds a training op.

    The objective has two parts:

    * the quantile Huber loss (Eqs. 9-10 of the referenced paper) between the
      target distribution and the chosen action's quantiles, reweighted by
      inverse priorities when ``self._replay_scheme == "prioritized"``;
    * a CQL term ``mean(logsumexp(Q)) - mean(Q(s, a_replay))``, scaled by
      ``self.minq_weight``.

    Returns:
      train_op: An op performing one step of training.
      loss: The per-sample quantile loss (inverse-priority weighted when
        prioritized replay is used).
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # size of indices: batch_size x 1.
    indices = tf.range(tf.shape(self._replay_net_outputs.logits)[0])[:, None]
    # size of reshaped_actions: batch_size x 2.
    reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
    # For each element of the batch, fetch the logits for its selected action.
    chosen_action_logits = tf.gather_nd(self._replay_net_outputs.logits,
                                        reshaped_actions)

    # Input `u' of Eq. 9.
    bellman_errors = (target_distribution[:, None, :] -
                      chosen_action_logits[:, :, None])
    # Eq. 9 of paper: quadratic within [-kappa, kappa], linear outside.
    abs_errors = tf.abs(bellman_errors)
    huber_loss = (
        tf.to_float(abs_errors <= self.kappa) * 0.5 * bellman_errors**2 +
        tf.to_float(abs_errors > self.kappa) * self.kappa *
        (abs_errors - 0.5 * self.kappa))

    # Quantile midpoints. See Lemma 2 of paper.
    tau_hat = (tf.range(self._num_atoms, dtype=tf.float32) + 0.5) / self._num_atoms

    # Eq. 10 of paper.
    quantile_huber_loss = (
        tf.abs(tau_hat[None, :, None] - tf.to_float(bellman_errors < 0)) *
        huber_loss)

    # Sum over tau dimension, average over target value dimension.
    loss = tf.reduce_sum(tf.reduce_mean(quantile_huber_loss, 2), 1)

    if self._replay_scheme == "prioritized":
        target_priorities = self._replay.tf_get_priority(self._replay.indices)
        # The original prioritized experience replay uses a linear exponent
        # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of 0.5
        # on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) suggested
        # a fixed exponent actually performs better, except on Pong.
        loss_weights = 1.0 / tf.sqrt(target_priorities + 1e-10)
        loss_weights /= tf.reduce_max(loss_weights)

        # Rainbow and prioritized replay are parametrized by an exponent alpha,
        # but in both cases it is set to 0.5 - for simplicity's sake we leave it
        # as is here, using the more direct tf.sqrt(). Taking the square root
        # "makes sense", as we are dealing with a squared loss.
        # Add a small nonzero value to the loss to avoid 0 priority items. While
        # technically this may be okay, setting all items to 0 priority will
        # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms.
        update_priorities_op = self._replay.tf_set_priority(
            self._replay.indices, tf.sqrt(loss + 1e-10))

        # Weight loss by inverse priorities.
        loss = loss_weights * loss
    else:
        update_priorities_op = tf.no_op()

    ### Add the CQL loss
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1.0, 0.0, name="action_one_hot")
    # `axis` replaces the deprecated `reduction_indices` alias.
    replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.q_values * replay_action_one_hot,
        axis=1,
        name="replay_chosen_q",
    )
    dataset_expec = tf.reduce_mean(replay_chosen_q)
    negative_sampling = tf.reduce_mean(
        tf.reduce_logsumexp(self._replay_net_outputs.q_values, 1))
    min_q_loss = negative_sampling - dataset_expec

    print("MIN Q WEIGHT: ", self.minq_weight)

    # Ops created in this scope depend on `update_priorities_op`, so running
    # the returned train op also performs the priority update.
    with tf.control_dependencies([update_priorities_op]):
        if self.summary_writer is not None:
            with tf.variable_scope("Losses"):
                tf.summary.scalar("QuantileLoss", tf.reduce_mean(loss))
                tf.summary.scalar("minQLoss", tf.reduce_mean(min_q_loss))
                tf.summary.scalar("Q_predictions", tf.reduce_mean(replay_chosen_q))
        # The summary above reports the unscaled CQL term; the scaled one is
        # what gets optimized.
        min_q_loss = min_q_loss * self.minq_weight
        return self.optimizer.minimize(tf.reduce_mean(loss) + min_q_loss), loss