import tensorflow as tf

# Module-level aliases assumed for the code below (TF 1.x): `tfd` and `dist` both refer
# to the tf.distributions API, which is what the original functions use.
tfd = tf.distributions
dist = tfd


def get_kl_loss(true_probs, pred_probs):
    assert len(true_probs.shape) == len(pred_probs.shape)
    # A small epsilon keeps the Categorical probs strictly positive so the KL stays finite.
    true = tfd.Categorical(probs=true_probs + 1e-9)
    pred = tfd.Categorical(probs=pred_probs + 1e-9)
    pointwise_diverge = tfd.kl_divergence(true, pred, allow_nan_stats=False)
    kl_loss = tf.reduce_sum(pointwise_diverge)
    return kl_loss

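def _example_kl_usage():
    # Minimal sketch (illustrative, not part of the original module) of get_kl_loss:
    # the KL divergence is computed pointwise between Categorical distributions over the
    # last axis and then summed over the remaining (batch) axes. Probabilities are made up.
    true_probs = tf.constant([[0.7, 0.2, 0.1],
                              [0.1, 0.8, 0.1]])
    pred_probs = tf.constant([[0.6, 0.3, 0.1],
                              [0.2, 0.7, 0.1]])
    return get_kl_loss(true_probs, pred_probs)  # scalar: sum of the two row-wise KLs
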
def _action_selection(self, current_state, G, time, is_training):
    # TODO: use pi = softmax(-F - gamma*G) instead?
    gamma = 1
    pi_logits = tf.nn.log_softmax(gamma * G)
    # Incorporate past into the decision. But what to use for the 2 decision actions?
    # pi_logits = tf.nn.log_softmax(tf.log(new_state['c']) + gamma * G)
    # TODO: precision?
    # TODO: which version of action selection?
    # Visual foraging code:  a_t ~ softmax(alpha * log(softmax(-F - gamma * G))) with alpha=512
    # Visual foraging paper: a_t = min_a[ o*_t+1 * [log(o*_t+1) - log(o^a_t+1)] ]
    # Maze code:             a_t ~ softmax(gamma * G) [summed over policies with the same next action]
    selected_action_idx = tf.cond(
        is_training,
        lambda: tfd.Categorical(logits=self.alpha * pi_logits, allow_nan_stats=False).sample(),
        lambda: tf.argmax(G, axis=1, output_type=tf.int32),
        name='sample_action_cond')
    # Give back the action itself, not its index. Differentiate between decision and location actions.
    best_belief = self._best_believe(current_state)
    dec = tf.equal(selected_action_idx, self.n_policies)  # the last action is the decision
    # Replace decision indices (which exceed the shape of selected_action), so we can use gather on the locations.
    selected_action_idx = tf.where(tf.stop_gradient(dec), tf.fill([self.B], 0), selected_action_idx)
    decision = tf.cond(
        tf.equal(time, self.num_glimpses - 1),
        lambda: best_belief,  # always take a decision at the last time step
        lambda: tf.where(dec, best_belief, tf.fill([self.B], -1)),
        name='last_t_decision_cond')
    return decision, selected_action_idx

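def _example_action_sampling():
    # Minimal sketch (illustrative, not part of the model) of the training-time sampling
    # rule used above: sharpen the log-softmax of the expected free energies G by alpha
    # before drawing a Categorical sample, which approaches argmax for large alpha.
    # G_toy, gamma and alpha are made-up values.
    gamma, alpha = 1.0, 512.0
    G_toy = tf.constant([[0.1, 0.5, 0.2]])  # [B=1, n_policies]
    pi_logits = tf.nn.log_softmax(gamma * G_toy)
    return tfd.Categorical(logits=alpha * pi_logits).sample()
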
def _discrete_entropy_agg(d, logits=None, probs=None, agg=True):
    # TODO: DOES MEAN MAKE SENSE? (at least better than sum, as indifferent to size_z)
    if d == 'B':
        dist = tfd.Bernoulli(logits=logits, probs=probs)
    elif d == 'Cat':
        dist = tfd.Categorical(logits=logits, probs=probs)
    H = dist.entropy()
    if agg:
        H = tf.reduce_mean(H, axis=-1)  # [B, n_policies, hyp]
    return H

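def _example_entropy_agg():
    # Minimal sketch (illustrative) of _discrete_entropy_agg: element-wise Bernoulli
    # entropies are averaged over the last axis rather than summed, so the result does
    # not grow with the number of latent dimensions. The probabilities are made up.
    probs = tf.constant([[0.5, 0.9, 0.1]])
    return _discrete_entropy_agg('B', probs=probs, agg=True)  # one value per batch row
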
def log_prob(self, zs, xs, T, z_lens, x_lens):
    """Computes the log probability of a set of samples.

    Args:
      zs: A set of [batch_size, max_z_num_timesteps, state_dim] latent states.
      xs: A set of [batch_size, max_x_num_timesteps, state_dim] observations.
      T: A set of [batch_size] integers denoting the number of censored steps.
      z_lens: A set of [batch_size] integers denoting the length of each sequence of zs.
      x_lens: A set of [batch_size] integers denoting the length of each sequence of
        observations. Note that T must equal z_lens - x_lens.

    Returns:
      log_p_z: A [batch_size, max_z_num_timesteps] set of log probs of zs.
      log_p_x_given_z: A [batch_size, max_x_num_timesteps] set of log probs of xs.
      log_p_T: A [batch_size] set of log probs of T.
    """
    # First, reverse the zs.
    rev_zs = tf.reverse_sequence(zs, z_lens, seq_axis=1, batch_axis=0)
    batch_size = tf.shape(zs)[0]
    # Compute means of z locations by adding the drift to each z.
    rev_z_locs = rev_zs[:, :-1, :] + self.drift[tf.newaxis, tf.newaxis, :]
    z0_mu = tf.tile(self.z0_mu[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1])
    rev_z_locs = tf.concat([z0_mu, rev_z_locs], axis=1)
    # Compute z log probs.
    rev_log_p_z = tfd.Normal(loc=rev_z_locs, scale=self.z_scale).log_prob(rev_zs)
    rev_log_p_z *= tf.sequence_mask(z_lens, dtype=rev_log_p_z.dtype)[:, :, tf.newaxis]
    # Reverse the log probs back.
    log_p_z = tf.reverse_sequence(rev_log_p_z, z_lens, seq_axis=1, batch_axis=0)
    log_p_z = tf.reduce_sum(log_p_z, axis=-1)
    # To compute the prob of xs, mask out all zs beyond the first x_len.
    masked_zs = zs * tf.sequence_mask(
        x_lens, maxlen=tf.reduce_max(z_lens), dtype=zs.dtype)[:, :, tf.newaxis]
    masked_zs = masked_zs[:, :tf.reduce_max(x_lens), :]
    log_p_x_given_z = tfd.Normal(loc=masked_zs, scale=self.x_scale).log_prob(xs)
    log_p_x_given_z *= tf.sequence_mask(
        x_lens, dtype=log_p_x_given_z.dtype)[:, :, tf.newaxis]
    log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=-1)
    log_p_T = tfd.Categorical(logits=self.T_logits).log_prob(T)
    return log_p_z, log_p_x_given_z, log_p_T

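def _example_reversed_prior_logprob():
    # Minimal sketch (illustrative, not from the original class) of the latent prior that
    # log_prob evaluates: in reversed time order the first z is Normal(z0_mu, z_scale) and
    # every later z adds a constant drift to its predecessor. All values are made up.
    z0_mu, drift, z_scale = 0.0, 0.5, 1.0
    rev_zs = tf.constant([0.2, 0.9, 1.3])                     # one reversed chain, state_dim = 1
    locs = tf.concat([[z0_mu], rev_zs[:-1] + drift], axis=0)  # per-step means
    return tf.reduce_sum(tfd.Normal(loc=locs, scale=z_scale).log_prob(rev_zs))
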
def textGenerate(self):
    # self.theta_part = tf.reduce_sum(tf.expand_dims(tf.expand_dims(1 - self.stop_prob, -1) * tf.expand_dims(theta_gen, 1), -1) * self.token_ppx_non_prob, 2)
    pred_next_token_theta = dist.Categorical(probs=self.h_part).sample()
    return pred_next_token_theta

def sample(self, batch_size, xs, x_lens):
    max_seq_len = tf.reduce_max(x_lens)
    rev_xs = tf.reverse_sequence(xs, x_lens, seq_axis=1, batch_axis=0)
    # Sample T from q(T | x) based on the last observation.
    T_logits = tf.matmul(rev_xs[:, 0, :], self.W_T) + self.b_T[tf.newaxis, :]
    q_T = tfd.Categorical(logits=T_logits)
    T = tf.stop_gradient(q_T.sample())
    z_lens = T + x_lens
    log_q_T = q_T.log_prob(T)
    rev_zs_ta = tf.TensorArray(dtype=self.dtype, size=max_seq_len,
                               dynamic_size=True, name="sample_zs")
    rev_log_q_z_ta = tf.TensorArray(dtype=self.dtype, size=max_seq_len,
                                    dynamic_size=True, name="log_q_z_ta")
    z0 = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    t0 = 0

    def while_predicate(t, *unused_args):
        return tf.reduce_any(t < T + x_lens)

    def while_step(t, prev_z, rev_log_q_z_ta, rev_zs_ta):
        # Compute the distribution over z_{T-t}.
        # [batch_size] steps till the next x.
        steps_till_next_x = tf.maximum(T - t, 0)
        # Fetch the next x value.
        next_x_ind = tf.minimum(tf.maximum(t - T, 0), x_lens - 1)
        r = tf.range(0, batch_size)
        inds = tf.stack([r, next_x_ind], axis=-1)
        x = tf.gather_nd(rev_xs, inds)
        z_loc_input = tf.concat(
            [x, prev_z, tf.to_float(steps_till_next_x)[:, tf.newaxis]], axis=1)
        z_loc = tf.matmul(z_loc_input, self.W_z) + self.b_z[tf.newaxis, :]
        log_sigmas = tf.gather(self.log_sigma, steps_till_next_x)
        z_scale = tf.math.maximum(tf.math.softplus(log_sigmas), self.sigma_min)
        q_z = tfd.Normal(loc=z_loc, scale=z_scale)
        new_z = q_z.sample()
        log_q_new_z = q_z.log_prob(new_z)
        # Zero out samples and log probs beyond each sequence's z length.
        new_z = tf.where(t < z_lens, new_z, tf.zeros_like(new_z))
        log_q_new_z = tf.where(t < z_lens, log_q_new_z, tf.zeros_like(log_q_new_z))
        new_rev_log_q_z_ta = rev_log_q_z_ta.write(t, log_q_new_z)
        new_rev_zs_ta = rev_zs_ta.write(t, new_z)
        return t + 1, new_z, new_rev_log_q_z_ta, new_rev_zs_ta

    _, _, rev_log_q_z_ta, rev_zs_ta = tf.while_loop(
        while_predicate,
        while_step,
        loop_vars=(t0, z0, rev_log_q_z_ta, rev_zs_ta),
        parallel_iterations=1)
    # rev_zs are currently [time, batch_size, state_dim].
    # We transpose to [batch_size, time, state_dim] to be consistent.
    rev_zs = tf.transpose(rev_zs_ta.stack(), [1, 0, 2])
    zs = tf.reverse_sequence(rev_zs, z_lens, seq_axis=1, batch_axis=0)
    # Sum the log q(z) over the state dimension and then transpose,
    # resulting in a [batch_size, time] Tensor.
    rev_log_q_z = tf.transpose(
        tf.reduce_sum(rev_log_q_z_ta.stack(), axis=-1), [1, 0])
    log_q_z = tf.reverse_sequence(rev_log_q_z, z_lens, seq_axis=1, batch_axis=0)
    return T, log_q_T, zs, log_q_z

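def _example_reverse_indexing():
    # Minimal sketch (plain Python, illustrative) of the index bookkeeping inside
    # while_step above: with T censored steps, the first T reversed steps keep
    # conditioning on the last observed x (index 0 of rev_xs), after which the loop walks
    # through the observations one by one. The numbers are made up.
    T, x_len = 2, 3
    for t in range(T + x_len):
        steps_till_next_x = max(T - t, 0)
        next_x_ind = min(max(t - T, 0), x_len - 1)
        print(t, steps_till_next_x, next_x_ind)
    # t:                 0  1  2  3  4
    # steps_till_next_x: 2  1  0  0  0
    # next_x_ind:        0  0  0  1  2
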
def _action_selection(self, next_actions, next_actions_mean, new_state, G,
                      exp_obs_prior, time, is_training):
    # TODO: should uniformLoc10 take random decisions or not?
    if self.actInfPolicy in ['random', 'uniformLoc10']:
        selected_action_idx = tf.random_uniform(shape=[self.B], minval=0,
                                                maxval=self.n_policies, dtype=tf.int32)
        if time < (self.num_glimpses - 1):
            decision = tf.fill([self.B], -1)
        else:
            decision = self._best_believe(new_state)
    else:
        # TODO: use pi = softmax(-F - gamma*G) instead?
        gamma = 1
        pi_logits = tf.nn.log_softmax(gamma * G)
        # Incorporate past into the decision. But what to use for the 2 decision actions?
        # pi_logits = tf.nn.log_softmax(tf.log(new_state['c']) + gamma * G)
        # TODO: precision?
        # TODO: which version of action selection?
        # Visual foraging code:  a_t ~ softmax(alpha * log(softmax(-F - gamma * G))) with alpha=512
        # Visual foraging paper: a_t = min_a[ o*_t+1 * [log(o*_t+1) - log(o^a_t+1)] ]
        # Maze code:             a_t ~ softmax(gamma * G) [summed over policies with the same next action]
        selected_action_idx = tf.cond(
            is_training,
            lambda: tfd.Categorical(logits=self.alpha * pi_logits,
                                    allow_nan_stats=False).sample(),
            lambda: tf.argmax(G, axis=1, output_type=tf.int32),
            name='sample_action_cond')
        # Give back the action itself, not its index. Differentiate between decision and location actions.
        best_belief = self._best_believe(new_state)
        dec = tf.equal(selected_action_idx, self.n_policies)  # the last action is the decision
        # Replace decision indices (which exceed the shape of selected_action), so we can use gather on the locations.
        selected_action_idx = tf.where(tf.stop_gradient(dec), tf.fill([self.B], 0),
                                       selected_action_idx)
        decision = tf.cond(
            tf.equal(time, self.num_glimpses - 1),
            lambda: best_belief,  # always take a decision at the last time step
            lambda: tf.where(dec, best_belief, tf.fill([self.B], -1)),
            name='last_t_decision_cond')

    if self.n_policies == 1:
        selected_action, selected_action_mean = next_actions, next_actions_mean
        # Squeeze out the policy dim (squeeze would turn the shape into unknown).
        selected_exp_obs = {
            k: tf.reshape(v, [self.B, self.num_classes_kn, v.shape[-1]]) if (v is not None) else None
            for k, v in exp_obs_prior.items()}
    else:
        coords = tf.stack(tf.meshgrid(tf.range(self.B)) + [selected_action_idx], axis=1)
        selected_action = tf.gather_nd(next_actions, coords)
        selected_action_mean = tf.gather_nd(next_actions_mean, coords)
        # [B, num_classes_kn, -1] as n_policies gets removed in gather_nd.
        selected_exp_obs = {
            k: tf.gather_nd(v, coords) if (v is not None) else None
            for k, v in exp_obs_prior.items()}
    return decision, selected_action, selected_action_mean, selected_exp_obs, selected_action_idx

def forward(self, inputs, params, mode="Train"):
    stop_indicator = tf.to_float(tf.expand_dims(inputs["indicators"], -1))
    seq_mask = tf.to_float(tf.sequence_mask(inputs["length"]))
    target_to_onehot = tf.expand_dims(tf.to_float(tf.one_hot(inputs["targets"], self.vocab_size)), 2)

    '''RNN Cell'''
    with tf.name_scope("RNN_CELL"):
        emb = tf.nn.embedding_lookup(self.embedding, inputs["tokens"])
        cells = [tf.nn.rnn_cell.GRUCell(self.num_units) for _ in range(self.num_layers)]
        cell = tf.nn.rnn_cell.MultiRNNCell(cells)
        rnn_outputs, final_output = tf.nn.dynamic_rnn(
            cell, inputs=emb, sequence_length=inputs["length"], dtype=tf.float32)

    '''Sampling theta q(theta|w; alpha)'''
    with tf.name_scope("theta"):
        emb_wo = tf.expand_dims(inputs["frequency"], -1) * tf.nn.embedding_lookup(self.embedding, inputs["targets"])
        alpha = tf.nn.softplus(tf.tensordot(emb_wo, self.theta_weight, [[1, 2], [0, 1]]))
        self.theta_point = alpha / (tf.expand_dims(tf.reduce_sum(alpha, -1), -1) + 1e-10)
        gamma = params["prior"] * tf.ones_like(alpha)
        pst_dist = tf.distributions.Dirichlet(alpha)
        pri_dist = tf.distributions.Dirichlet(gamma)
        # KL divergence for theta.
        theta_kl_loss = pst_dist.kl_divergence(pri_dist)
        theta_kl_loss = tf.reduce_mean(theta_kl_loss, -1)
        self.theta = pst_dist.sample()

    '''Phi matrix'''
    with tf.name_scope("Phi"):
        self.phi = tf.nn.dropout(
            tf.nn.softmax(tf.contrib.layers.batch_norm(tf.layers.dense(emb_wo, self.num_topics)), -1),
            inputs["dropout"])
        # self.phi = tf.nn.dropout(tf.nn.softmax(tf.layers.dense(emb_wo, self.num_topics), -1), inputs["dropout"])
        self.phi = ((1 - stop_indicator) * self.phi) + (stop_indicator * (1. / self.num_topics))

    '''Token loss (reconstruction loss)'''
    with tf.name_scope("token_loss"):
        h_prob = tf.expand_dims(
            tf.nn.softmax(tf.layers.dense(rnn_outputs, units=self.vocab_size, use_bias=False), -1), 2)
        b_prob = tf.expand_dims(
            tf.pad(tf.nn.softmax(tf.contrib.layers.batch_norm(self.beta), -1), self.paddings, "CONSTANT"), 0)
        token_logits = (1 - (params["mixture_lambda"] * (1 - tf.expand_dims(stop_indicator, -1)))) * h_prob \
            + params["mixture_lambda"] * tf.expand_dims(1 - stop_indicator, -1) * b_prob
        token_loss = tf.log(tf.reduce_sum(target_to_onehot * token_logits, -1) + 1e-4)
        token_loss = seq_mask * tf.reduce_sum(self.phi * token_loss, -1)
        token_loss = -tf.reduce_mean(tf.reduce_sum(token_loss, axis=-1))

    with tf.name_scope("indicator_loss"):
        # indicator_logits = tf.squeeze(tf.layers.dense(rnn_outputs, units=1, activation=tf.nn.softplus), axis=2)
        indicator_logits = tf.squeeze(
            tf.contrib.layers.batch_norm(
                tf.layers.dense(tf.layers.dense(rnn_outputs, units=5, activation=tf.nn.softplus),
                                units=1, activation=tf.nn.softplus)),
            axis=2)
        indicator_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(inputs["indicators"]), logits=indicator_logits, name="indicator_loss")
        indicator_loss = tf.reduce_mean(tf.reduce_sum(seq_mask * indicator_loss, -1))
        indicator_acc = tf.reduce_mean(
            tf.to_float(tf.equal(tf.round(tf.nn.sigmoid(indicator_logits)),
                                 tf.to_float(inputs["indicators"]))), -1)
        indicator_acc = tf.reduce_mean(indicator_acc)

    with tf.name_scope("Perplexity"):
        k_temp = tf.nn.sigmoid(indicator_logits) * tf.squeeze(tf.reduce_sum(target_to_onehot * h_prob, -1), -1)
        token_ppl = tf.exp(
            -tf.reduce_sum(
                seq_mask * tf.log(
                    tf.reduce_sum(
                        tf.expand_dims(1 - tf.nn.sigmoid(indicator_logits), -1) * self.phi * (1 - stop_indicator)
                        * tf.reduce_sum(target_to_onehot * ((1 - params["mixture_lambda"]) * h_prob
                                                            + params["mixture_lambda"] * b_prob), -1),
                        -1)
                    + k_temp + 1e-10))
            / (1e-10 + tf.to_float(tf.reduce_sum(inputs["length"]))))

    with tf.name_scope("TextGenerate"):
        k_text_temp = tf.expand_dims(tf.nn.sigmoid(indicator_logits), -1) * tf.squeeze(h_prob, 2)
        phi_text_temp = tf.reduce_sum(
            tf.expand_dims(tf.expand_dims(1 - tf.nn.sigmoid(indicator_logits), -1) * self.phi * (1 - stop_indicator), -1)
            * ((1 - params["mixture_lambda"]) * h_prob + params["mixture_lambda"] * b_prob), 2)
        # pred_next_token = tf.argmax(k_text_temp + phi_text_temp, -1)
        pred_next_token = dist.Categorical(probs=k_text_temp + phi_text_temp).sample()

    with tf.name_scope("TextGenerateTheta"):
        # theta_text_temp = tf.reduce_sum(tf.expand_dims(tf.expand_dims(1 - tf.nn.sigmoid(indicator_logits), -1) * tf.expand_dims(self.theta, 1), -1) * ((1 - params["mixture_lambda"]) * h_prob + params["mixture_lambda"] * b_prob), 2)
        theta_text_temp = tf.reduce_sum(
            tf.expand_dims(tf.expand_dims(1 - tf.nn.sigmoid(indicator_logits), -1) * tf.expand_dims(self.theta_point, 1), -1)
            * ((1 - params["mixture_lambda"]) * h_prob + params["mixture_lambda"] * b_prob), 2)
        # pred_next_token_theta = tf.argmax(k_text_temp + theta_text_temp, -1)
        pred_next_token_theta = dist.Categorical(probs=k_text_temp + theta_text_temp).sample()
        print('-' * 50)
        print('pred_next_token_theta', pred_next_token_theta.get_shape())
        print('-' * 50)

    '''KL between phi and theta'''
    with tf.name_scope("Phi_theta_kl"):
        theta = tf.expand_dims(self.theta, 1)
        phi_theta_kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.squeeze(1 - stop_indicator, -1)
                * tf.reduce_sum((1 - stop_indicator) * self.phi
                                * tf.log((((1 - stop_indicator) * self.phi) / (theta + 1e-10)) + 1e-10), -1),
                -1))

    total_loss = token_loss + theta_kl_loss + indicator_loss + phi_theta_kl_loss

    with tf.name_scope("SwitchP"):
        # all_topics = dist.Categorical(probs=self.phi).sample()
        all_topics = tf.argmax(self.phi, -1)

    with tf.name_scope("Entropies"):
        theta_entropy = tf.reduce_mean(tf.reduce_sum(-self.theta * tf.log(self.theta + 1e-10), -1))
        phi_entropy = tf.reduce_mean(
            tf.reduce_sum(tf.to_float(1 - inputs["indicators"])
                          * tf.reduce_sum(-self.phi * tf.log(self.phi + 1e-10), -1), -1)
            / tf.reduce_sum(tf.to_float(1 - inputs["indicators"])), -1)

    tf.summary.scalar(tensor=token_loss, name=mode + " token_loss")
    tf.summary.scalar(tensor=phi_theta_kl_loss, name=mode + " phi_theta_kl_loss")
    tf.summary.scalar(tensor=indicator_loss, name=mode + " indicator_loss")
    tf.summary.scalar(tensor=theta_kl_loss, name=mode + " theta_kl_loss")
    tf.summary.scalar(tensor=total_loss, name=mode + " total_loss")
    tf.summary.scalar(tensor=token_ppl, name=mode + " token_ppl")

    outputs = {
        "token_loss": token_loss,
        "token_ppl": token_ppl,
        "indicator_loss": indicator_loss,
        "theta_kl_loss": theta_kl_loss,
        "phi_theta_kl_loss": phi_theta_kl_loss,
        "loss": total_loss,
        "theta": self.theta,
        "repre": final_output[-1][1],
        "beta": self.beta,
        "all_topics": all_topics,
        "non_stop_indic": 1 - inputs["indicators"],
        "phi": self.phi,
        "pred_next_token": pred_next_token,
        "accuracy": indicator_acc,
        "pred_next_token_theta": pred_next_token_theta,
        "theta_entropy": theta_entropy,
        # "phi_entropy": phi_entropy,
    }
    return outputs

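def _example_forward_inputs():
    # Illustrative only: the dictionary keys that forward() reads from `inputs` and
    # `params`, collected from the code above. Shapes and dtypes are assumptions for a
    # toy batch of 2 sequences; the real input pipeline is not shown here.
    inputs = {
        "tokens": tf.zeros([2, 5], dtype=tf.int32),      # token ids fed to the RNN
        "targets": tf.zeros([2, 5], dtype=tf.int32),     # next-token targets
        "indicators": tf.zeros([2, 5], dtype=tf.int32),  # stop-word indicators
        "length": tf.constant([5, 3], dtype=tf.int32),   # sequence lengths
        "frequency": tf.ones([2, 5], dtype=tf.float32),  # per-token frequency weights
        "dropout": tf.constant(1.0),                     # keep probability for the phi dropout
    }
    params = {"prior": 0.1, "mixture_lambda": 0.5}       # Dirichlet prior, mixture weight
    return inputs, params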