Example #1
0
def get_kl_loss(true_probs, pred_probs, eps=1e-9):
    """Return the summed KL(true || pred) between two categorical distributions.

    Args:
      true_probs: Tensor of reference probabilities. Its last axis indexes the
        categories, as interpreted by `tf.distributions.Categorical`.
      pred_probs: Tensor of predicted probabilities. It must have the same
        static rank as `true_probs`.
      eps: Constant added to both probability tensors before the distributions
        are built, so zero entries do not produce log(0). Defaults to 1e-9,
        the value that was previously hard-coded.

    Returns:
      Scalar tensor: the pointwise KL divergence summed over all batch
      positions.

    Raises:
      ValueError: if the static ranks of `true_probs` and `pred_probs` differ.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    true_rank = len(true_probs.shape)
    pred_rank = len(pred_probs.shape)
    if true_rank != pred_rank:
        raise ValueError(
            "true_probs and pred_probs must have the same rank, got %d and %d"
            % (true_rank, pred_rank))
    # Reach the distributions module through the `tf` module this function
    # already uses, rather than importing it as a submodule path.
    tfd = tf.distributions
    true = tfd.Categorical(probs=true_probs + eps)
    pred = tfd.Categorical(probs=pred_probs + eps)
    # allow_nan_stats=False: fail with an error instead of returning NaN.
    pointwise_diverge = tfd.kl_divergence(true, pred, allow_nan_stats=False)
    kl_loss = tf.reduce_sum(pointwise_diverge)
    return kl_loss
    def _action_selection(self, current_state, G, time, is_training):
        """Choose the next action index and, where applicable, a decision.

        Column `self.n_policies` of `G` is the "take a decision now" action.
        Lower indices are location actions.

        Args:
          current_state: belief state, passed to `self._best_believe`.
          G: action scores, presumably [B, n_policies + 1] (TODO confirm).
            `log_softmax` runs over the last axis and `argmax` over axis 1.
          time: current glimpse index (Python int or scalar tensor; both are
            accepted by `tf.equal`).
          is_training: boolean tensor. If true, the action is sampled from
            softmax(alpha * log_softmax(G)); otherwise argmax(G) is taken.

        Returns:
          decision: `self._best_believe(current_state)` where a decision is
            taken (always at the final glimpse), -1 elsewhere.
          selected_action_idx: chosen index, with the decision index replaced
            by 0 so it can be used to gather location actions.
        """
        # Variants considered by the original author (not implemented):
        #   pi = softmax(-F - gamma*G)
        #   pi_logits = tf.nn.log_softmax(tf.log(new_state['c']) + gamma * G)
        #     (incorporate the past; unclear what to use for the decision actions)
        #   visual foraging code: a_t ~ softmax(alpha * log(softmax(-F - gamma * G))), alpha=512
        #   visual foraging paper: a_t = min_a[ o*_t+1 * [log(o*_t+1) - log(o^a_t+1)] ]
        #   maze code: a_t ~ softmax(gamma * G), summed over policies sharing the next action
        # TODO: precision? Which version of action selection?
        gamma = 1
        log_pi = tf.nn.log_softmax(gamma * G)

        def _sample_action():
            policy = tfd.Categorical(logits=self.alpha * log_pi,
                                     allow_nan_stats=False)
            return policy.sample()

        def _greedy_action():
            return tf.argmax(G, axis=1, output_type=tf.int32)

        action_idx = tf.cond(is_training, _sample_action, _greedy_action,
                             name='sample_action_cond')

        best_belief = self._best_believe(current_state)
        is_decision = tf.equal(action_idx, self.n_policies)
        # The decision index exceeds the number of location actions; map it to
        # 0 so the returned index is always safe to gather with.
        action_idx = tf.where(tf.stop_gradient(is_decision),
                              tf.fill([self.B], 0),
                              action_idx)

        on_last_glimpse = tf.equal(time, self.num_glimpses - 1)
        decision = tf.cond(
            on_last_glimpse,
            lambda: best_belief,
            lambda: tf.where(is_decision, best_belief, tf.fill([self.B], -1)),
            name='last_t_decision_cond')
        return decision, action_idx
 def _discrete_entropy_agg(d, logits=None, probs=None, agg=True):
     """Return the entropy of a Bernoulli or Categorical distribution.

     Args:
       d: distribution family, 'B' for Bernoulli or 'Cat' for Categorical.
       logits: logits forwarded to the distribution constructor.
       probs: probabilities forwarded to the distribution constructor.
       agg: if True, average the entropy over its last axis.

     Returns:
       The entropy tensor. When `agg` is True, its last axis is reduced by mean.

     Raises:
       ValueError: if `d` is neither 'B' nor 'Cat'. Previously this surfaced
         as an UnboundLocalError on `dist`.
     """
     # TODO: DOES MEAN MAKE SENSE? (at least better than sum, as indifferent to size_z)
     if d == 'B':
         dist = tfd.Bernoulli(logits=logits, probs=probs)
     elif d == 'Cat':
         dist = tfd.Categorical(logits=logits, probs=probs)
     else:
         raise ValueError(
             "Unknown distribution type %r; expected 'B' or 'Cat'." % (d,))
     H = dist.entropy()
     if agg:
         H = tf.reduce_mean(H, axis=-1)  # [B, n_policies, hyp]
     return H
Example #4
0
    def log_prob(self, zs, xs, T, z_lens, x_lens):
        """Computes the log probability of a set of samples.

        The latent chain is scored in reverse time:
          * the last valid z of each sequence is scored under
            Normal(z0_mu, z_scale);
          * every earlier z is scored under Normal(following z + drift, z_scale).
        Observation x_t is scored under Normal(z_t, x_scale) for t < x_len.
        Padded positions contribute 0.

        Masks are sized to the padded time dimension of the tensors they
        multiply, not to the longest sequence. Batches padded beyond the longest
        sequence therefore work too (xs must not be padded longer than zs).
        When the padding equals the longest sequence, results are unchanged.

        Args:
          zs: A set of [batch_size, max_z_num_timesteps, state_dim] latent states.
          xs: A set of [batch_size, max_x_num_timesteps, state_dim] observations.
          T: A set of [batch_size] integers denoting the number of censored steps.
          z_lens: A set of [batch_size] integers denoting the length of each
            sequence of zs.
          x_lens: A set of [batch_size] integers denoting the length of each
            sequence of observations. Note that T must equal z_lens - x_lens.
        Returns:
          log_p_z: A [batch_size, max_z_num_timesteps] set of logprobs of zs.
          log_p_x_given_z: A [batch_size, max_x_num_timesteps] set of logprobs of xs.
          log_p_T: A [batch_size] set of logprobs of T.
        """
        # First, reverse the zs so each sequence's last valid step comes first.
        rev_zs = tf.reverse_sequence(zs, z_lens, seq_axis=1, batch_axis=0)
        batch_size = tf.shape(zs)[0]
        # Padded time lengths of the inputs. They are used as mask lengths so
        # the masks always broadcast against the tensors they multiply.
        z_steps = tf.shape(zs)[1]
        x_steps = tf.shape(xs)[1]
        # Compute means of z locations by adding drift to each z
        rev_z_locs = rev_zs[:, :-1, :] + self.drift[tf.newaxis, tf.newaxis, :]
        z0_mu = tf.tile(self.z0_mu[tf.newaxis, tf.newaxis, :],
                        [batch_size, 1, 1])
        rev_z_locs = tf.concat([z0_mu, rev_z_locs], axis=1)
        # Compute z log probs.
        rev_log_p_z = tfd.Normal(loc=rev_z_locs,
                                 scale=self.z_scale).log_prob(rev_zs)
        rev_log_p_z *= tf.sequence_mask(z_lens,
                                        maxlen=z_steps,
                                        dtype=rev_log_p_z.dtype)[:, :,
                                                                 tf.newaxis]
        # Reverse the log probs back
        log_p_z = tf.reverse_sequence(rev_log_p_z,
                                      z_lens,
                                      seq_axis=1,
                                      batch_axis=0)
        log_p_z = tf.reduce_sum(log_p_z, axis=-1)

        # To compute the prob of xs, mask out all zs beyond the first x_len,
        # then keep as many steps as xs has.
        masked_zs = zs * tf.sequence_mask(x_lens,
                                          maxlen=z_steps,
                                          dtype=zs.dtype)[:, :, tf.newaxis]
        masked_zs = masked_zs[:, :x_steps, :]
        log_p_x_given_z = tfd.Normal(loc=masked_zs,
                                     scale=self.x_scale).log_prob(xs)
        log_p_x_given_z *= tf.sequence_mask(
            x_lens, maxlen=x_steps,
            dtype=log_p_x_given_z.dtype)[:, :, tf.newaxis]
        log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=-1)

        log_p_T = tfd.Categorical(logits=self.T_logits).log_prob(T)
        return log_p_z, log_p_x_given_z, log_p_T
Example #5
0
 def textGenerate(self):
     """Return indices sampled from a Categorical parameterized by `self.h_part`.

     `self.h_part` is passed as `probs`, so its last axis indexes the
     categories. These are presumably vocabulary tokens (confirm with caller).
     """
     # Alternative formulation kept from the original author (unused):
     # self.theta_part=tf.reduce_sum(tf.expand_dims(tf.expand_dims(1-self.stop_prob,-1)*tf.expand_dims(theta_gen,1),-1)*self.token_ppx_non_prob,2)
     return dist.Categorical(probs=self.h_part).sample()
Example #6
0
    def sample(self, batch_size, xs, x_lens):
        """Sample censoring lengths T and latent sequences zs from the proposal.

        T is drawn from a Categorical whose logits are a linear function of the
        last valid observation of each sequence. Latents are then generated
        sequentially in reverse time, one step per loop iteration:
          * During the first T reversed steps, the step input is the final
            observation together with the number of steps left until it.
          * Afterwards, the step input is the reversed observation at index
            t - T (clamped to x_len - 1) with a step count of 0.

        Args:
          batch_size: number of sequences in the batch (used for shapes).
          xs: observations, presumably [batch_size, max_x_num_timesteps, obs_dim]
            (TODO confirm against caller).
          x_lens: [batch_size] integer lengths of the sequences in `xs`.

        Returns:
          T: [batch_size] sampled censoring lengths (gradients stopped).
          log_q_T: [batch_size] log q(T).
          zs: [batch_size, max(z_lens), state_size] sampled latents in forward
            time, zero beyond each sequence's z_len = T + x_len.
          log_q_z: [batch_size, max(z_lens)] log q(z) summed over the state
            dimension, zero beyond each z_len.
        """
        max_seq_len = tf.reduce_max(x_lens)
        rev_xs = tf.reverse_sequence(xs, x_lens, seq_axis=1, batch_axis=0)

        # Sample T from the last valid observation (index 0 after reversal).
        T_logits = tf.matmul(rev_xs[:, 0, :],
                             self.W_T) + self.b_T[tf.newaxis, :]
        q_T = tfd.Categorical(logits=T_logits)
        T = tf.stop_gradient(q_T.sample())
        z_lens = T + x_lens
        log_q_T = q_T.log_prob(T)

        # max_seq_len is only the initial size; the arrays grow as needed.
        rev_zs_ta = tf.TensorArray(dtype=self.dtype,
                                   size=max_seq_len,
                                   dynamic_size=True,
                                   name="sample_zs")
        rev_log_q_z_ta = tf.TensorArray(dtype=self.dtype,
                                        size=max_seq_len,
                                        dynamic_size=True,
                                        name="log_q_z_ta")
        z0 = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
        t0 = 0

        def while_predicate(t, *unused_args):
            # Continue while any sequence still has latent steps left.
            return tf.reduce_any(t < T + x_lens)

        def while_step(t, prev_z, rev_log_q_z_ta, rev_zs_ta):
            # Distribution over the t-th latent in reversed time.

            # [batch_size] steps till next x
            steps_till_next_x = tf.maximum(T - t, 0)
            # Fetch the next x value.
            next_x_ind = tf.minimum(tf.maximum(t - T, 0), x_lens - 1)
            r = tf.range(0, batch_size)
            inds = tf.stack([r, next_x_ind], axis=-1)
            x = tf.gather_nd(rev_xs, inds)

            # tf.concat needs matching dtypes. Cast the step counter to the
            # latent dtype instead of hard-coding float32 with tf.to_float.
            z_loc_input = tf.concat(
                [x, prev_z,
                 tf.cast(steps_till_next_x, prev_z.dtype)[:, tf.newaxis]],
                axis=1)
            z_loc = tf.matmul(z_loc_input, self.W_z) + self.b_z[tf.newaxis, :]
            log_sigmas = tf.gather(self.log_sigma, steps_till_next_x)
            z_scale = tf.math.maximum(tf.math.softplus(log_sigmas),
                                      self.sigma_min)
            q_z = tfd.Normal(loc=z_loc, scale=z_scale)
            new_z = q_z.sample()
            log_q_new_z = q_z.log_prob(new_z)

            # Zero out samples and log probs for sequences that are finished.
            new_z = tf.where(t < z_lens, new_z, tf.zeros_like(new_z))
            log_q_new_z = tf.where(t < z_lens, log_q_new_z,
                                   tf.zeros_like(log_q_new_z))

            new_rev_log_q_z_ta = rev_log_q_z_ta.write(t, log_q_new_z)
            new_rev_zs_ta = rev_zs_ta.write(t, new_z)
            return t + 1, new_z, new_rev_log_q_z_ta, new_rev_zs_ta

        # Run the reversed-time sampler. Each step depends on the previous z,
        # hence parallel_iterations=1.
        _, _, rev_log_q_z_ta, rev_zs_ta = tf.while_loop(
            while_predicate,
            while_step,
            loop_vars=(t0, z0, rev_log_q_z_ta, rev_zs_ta),
            parallel_iterations=1)

        # rev_zs are currently [time, batch_size, state_dim].
        # We transpose to [batch_size, time, state_dim] to be consistent.
        rev_zs = tf.transpose(rev_zs_ta.stack(), [1, 0, 2])
        zs = tf.reverse_sequence(rev_zs, z_lens, seq_axis=1, batch_axis=0)
        # Sum the log q(z) over the state dimension and then transpose,
        # resulting in a [batch_size, time] Tensor.
        rev_log_q_z = tf.transpose(
            tf.reduce_sum(rev_log_q_z_ta.stack(), axis=-1), [1, 0])
        log_q_z = tf.reverse_sequence(rev_log_q_z,
                                      z_lens,
                                      seq_axis=1,
                                      batch_axis=0)
        return T, log_q_T, zs, log_q_z
    def _action_selection(self, next_actions, next_actions_mean, new_state, G,
                          exp_obs_prior, time, is_training):
        """Select the next action and, where applicable, a classification decision.

        The mode is chosen by `self.actInfPolicy`:
          * 'random' / 'uniformLoc10': the index is drawn uniformly from
            [0, self.n_policies), so the decision action is never drawn. A
            decision is taken only at the last glimpse.
          * otherwise: the index is sampled from
            softmax(alpha * log_softmax(G)) when `is_training` is true, or is
            argmax(G) otherwise. Index `self.n_policies` means "take a decision
            now"; it is mapped to 0 in the returned index so it can be used
            for gathering.

        Args:
          next_actions: candidate actions per policy. When n_policies > 1 they
            are gathered at [batch, selected index].
          next_actions_mean: means of the candidate actions, gathered the same way.
          new_state: belief state, passed to `self._best_believe`.
          G: action scores, presumably [B, n_policies + 1] (TODO confirm).
          exp_obs_prior: dict of expected-observation tensors (values may be
            None), reduced to the selected policy.
          time: current glimpse index, either a Python int or a scalar tensor.
          is_training: boolean tensor selecting sampling vs. argmax.

        Returns:
          decision: best belief where a decision is taken, -1 elsewhere.
          selected_action, selected_action_mean: chosen action and its mean.
          selected_exp_obs: `exp_obs_prior` restricted to the chosen policy.
          selected_action_idx: chosen policy index (decision mapped to 0).
        """
        # TODO: should uniformLoc10 take random decisions or not?
        if self.actInfPolicy in ['random', 'uniformLoc10']:
            selected_action_idx = tf.random_uniform(shape=[self.B],
                                                    minval=0,
                                                    maxval=self.n_policies,
                                                    dtype=tf.int32)
            if isinstance(time, tf.Tensor):
                # A Python `if` on a tensor raises TypeError in graph mode, so
                # branch inside the graph, as the policy branch below does.
                decision = tf.cond(
                    tf.less(time, self.num_glimpses - 1),
                    lambda: tf.fill([self.B], -1),
                    lambda: self._best_believe(new_state),
                    name='random_decision_cond')
            elif time < (self.num_glimpses - 1):
                decision = tf.fill([self.B], -1)
            else:
                decision = self._best_believe(new_state)
        else:
            # TODO: use pi = softmax(-F - gamma*G) instead?
            gamma = 1
            pi_logits = tf.nn.log_softmax(gamma * G)
            # Incorporate past into the decision. But what to use for the 2 decision actions?
            # pi_logits = tf.nn.log_softmax(tf.log(new_state['c']) + gamma * G)

            # TODO: precision?
            # TODO: which version of action selection?
            # Visual foraging code: a_t ~ softmax(alpha * log(softmax(-F - gamma * G))) with alpha=512
            # Visual foraging paper: a_t = min_a[ o*_t+1 * [log(o*_t+1) - log(o^a_t+1)] ]
            # Maze code: a_t ~ softmax(gamma * G) [summed over policies with the same next action]
            selected_action_idx = tf.cond(
                is_training,
                lambda: tfd.Categorical(logits=self.alpha * pi_logits,
                                        allow_nan_stats=False).sample(),
                lambda: tf.argmax(G, axis=1, output_type=tf.int32),
                name='sample_action_cond')
            best_belief = self._best_believe(new_state)
            # The last action (index self.n_policies) is the decision.
            dec = tf.equal(selected_action_idx, self.n_policies)
            # Replace decision indices (which exceed the number of location
            # actions) by 0 so the index can be used with gather_nd below.
            selected_action_idx = tf.where(tf.stop_gradient(dec),
                                           tf.fill([self.B], 0),
                                           selected_action_idx)
            # Always take a decision at the last time step.
            decision = tf.cond(
                tf.equal(time, self.num_glimpses - 1),
                lambda: best_belief,
                lambda: tf.where(dec, best_belief, tf.fill([self.B], -1)),
                name='last_t_decision_cond')

        if self.n_policies == 1:
            selected_action, selected_action_mean = next_actions, next_actions_mean
            # Squeeze out the policy dim with reshape (squeeze would make the
            # static shape unknown).
            selected_exp_obs = {
                k: tf.reshape(v, [self.B, self.num_classes_kn, v.shape[-1]]) if
                (v is not None) else None
                for k, v in exp_obs_prior.items()
            }
        else:
            # (batch index, policy index) pairs for gather_nd.
            coords = tf.stack([tf.range(self.B), selected_action_idx], axis=1)
            selected_action = tf.gather_nd(next_actions, coords)
            selected_action_mean = tf.gather_nd(next_actions_mean, coords)
            # [B, num_classes_kn, -1], as the policy dim is removed by gather_nd
            selected_exp_obs = {
                k: tf.gather_nd(v, coords) if (v is not None) else None
                for k, v in exp_obs_prior.items()
            }
        return decision, selected_action, selected_action_mean, selected_exp_obs, selected_action_idx
Example #8
0
  def forward(self, inputs,params, mode="Train"):
    """Build the topic-augmented RNN language-model graph and its losses.

    Side effects:
      * sets self.theta_point, self.theta and self.phi;
      * registers tf.summary scalars named `mode + " <metric>"`;
      * prints the static shape of pred_next_token_theta.

    Args:
      inputs: dict of tensors. Keys read here: "indicators", "length",
        "targets", "tokens", "frequency", "dropout". Shapes are presumably
        [batch, time] for the per-token keys (TODO confirm against the input
        pipeline).
      params: dict of hyper-parameters.
        "prior" scales the concentration of the Dirichlet prior over theta.
        "mixture_lambda" weights the topic-word distribution (b_prob) against
        the RNN softmax (h_prob) for non-stop tokens.
      mode: prefix for the summary names (default "Train").

    Returns:
      dict of tensors, including:
        * the individual losses and their sum ("loss");
        * token perplexity and indicator accuracy;
        * sampled next tokens and the argmax topic per position;
        * theta/phi/beta;
        * an entry of the final RNN state ("repre").
    """

    # NOTE(review): "indicators" presumably marks stop-word positions with 1.
    # Stop positions get a uniform phi and no topic-word mixture below.
    stop_indicator=tf.to_float(tf.expand_dims(inputs["indicators"],-1))
    seq_mask=tf.to_float(tf.sequence_mask(inputs["length"]))
    target_to_onehot=tf.expand_dims(tf.to_float(tf.one_hot(inputs["targets"],self.vocab_size)),2)

    '''RNN Cell'''
    with tf.name_scope("RNN_CELL"):
      emb = tf.nn.embedding_lookup(self.embedding, inputs["tokens"])
      cells = [tf.nn.rnn_cell.GRUCell(self.num_units) for _ in range(self.num_layers)]
      cell = tf.nn.rnn_cell.MultiRNNCell(cells)
      rnn_outputs, final_output = tf.nn.dynamic_rnn(cell, inputs=emb, sequence_length=inputs["length"], dtype=tf.float32)

    ''' Sampling theta q(theta|w;alpha)'''
    with tf.name_scope("theta"):
        # Frequency-weighted embeddings of the target words, projected to
        # positive Dirichlet concentrations alpha via softplus.
        emb_wo=tf.expand_dims(inputs["frequency"],-1)*tf.nn.embedding_lookup(self.embedding,inputs["targets"])
        alpha = tf.nn.softplus(tf.tensordot(emb_wo,self.theta_weight,[[1,2],[0,1]]))
        # Normalized alpha, i.e. the mean of Dirichlet(alpha).
        self.theta_point=alpha/(tf.expand_dims(tf.reduce_sum(alpha,-1),-1)+1e-10)


        gamma =params["prior"]*tf.ones_like(alpha)

        # Posterior q(theta) = Dirichlet(alpha), prior p(theta) = Dirichlet(gamma).
        pst_dist = tf.distributions.Dirichlet(alpha)
        pri_dist = tf.distributions.Dirichlet(gamma)

        '''kl_divergence for theta'''
        theta_kl_loss=pst_dist.kl_divergence(pri_dist)
        theta_kl_loss=tf.reduce_mean(theta_kl_loss,-1)
        self.theta=pst_dist.sample()


    ''' Phi Matrix '''
    with tf.name_scope("Phi"):
      # NOTE(review): the `-1` here is the second positional argument of
      # tf.contrib.layers.batch_norm, i.e. `decay=-1`, not the softmax axis
      # (compare the commented-out variant below). batch_norm also defaults
      # to is_training=True, so batch statistics are used in every mode.
      # In TF 1.x the second positional argument of tf.nn.dropout is keep_prob.
      self.phi=tf.nn.dropout(tf.nn.softmax(tf.contrib.layers.batch_norm(tf.layers.dense(emb_wo,self.num_topics),-1)),inputs["dropout"])
      # self.phi=tf.nn.dropout(tf.nn.softmax(tf.layers.dense(emb_wo,self.num_topics),-1),inputs["dropout"])

      # Stop-word positions get a uniform topic distribution.
      self.phi=((1-stop_indicator)*self.phi)+((stop_indicator)*(1./self.num_topics))


    '''Token loss (Reconstruction Loss)'''
    with tf.name_scope("token_loss"):
      # h_prob: RNN softmax over the vocabulary.
      # b_prob: per-topic word distributions from self.beta, padded with
      # self.paddings.
      h_prob=tf.expand_dims(tf.nn.softmax(tf.layers.dense(rnn_outputs, units=self.vocab_size, use_bias=False),-1),2)
      b_prob=tf.expand_dims(tf.pad(tf.nn.softmax(tf.contrib.layers.batch_norm(self.beta),-1),self.paddings,"CONSTANT"),0)
      # Despite the name, token_logits holds mixture *probabilities*: non-stop
      # tokens mix h_prob and b_prob with weight mixture_lambda; stop tokens
      # use h_prob only.
      token_logits = (1-(params["mixture_lambda"]*(1-tf.expand_dims(stop_indicator,-1))))*h_prob+params["mixture_lambda"]*tf.expand_dims(1-stop_indicator,-1)*b_prob
      token_loss=tf.log(tf.reduce_sum(target_to_onehot*token_logits,-1)+1e-4)
      # Weight each topic's log-likelihood by phi, mask padding, then sum over
      # time and average over the batch.
      token_loss=seq_mask*tf.reduce_sum(self.phi*token_loss,-1)
      token_loss = -tf.reduce_mean(tf.reduce_sum(token_loss, axis=-1))


    with tf.name_scope("indicator_loss"):
      # indicator_logits = tf.squeeze(tf.layers.dense(rnn_outputs,  units=1,activation=tf.nn.softplus), axis=2)
      indicator_logits = tf.squeeze(tf.contrib.layers.batch_norm(tf.layers.dense(tf.layers.dense(rnn_outputs,  units=5,activation=tf.nn.softplus),units=1,activation=tf.nn.softplus)), axis=2)

      indicator_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.to_float(inputs["indicators"]),logits=indicator_logits,name="indicator_loss")
      indicator_loss=tf.reduce_mean(tf.reduce_sum(seq_mask*indicator_loss,-1))
      # NOTE(review): unlike the loss above, the accuracy is not multiplied by
      # seq_mask, so padded positions are counted.
      indicator_acc=tf.reduce_mean(tf.to_float(tf.equal(tf.round(tf.nn.sigmoid(indicator_logits)),tf.to_float(inputs["indicators"]))),-1)
      indicator_acc=tf.reduce_mean(indicator_acc)





    with tf.name_scope("Perplexity"):
        # Per-token likelihood:
        #   sigmoid(indicator) * h_prob(target)
        #   + (1 - sigmoid(indicator)) * sum_k phi_k * mixture(target).
        # ppl = exp(-sum(masked log-lik) / total length).
        k_temp=tf.nn.sigmoid(indicator_logits)*tf.squeeze(tf.reduce_sum(target_to_onehot*h_prob,-1),-1)
        token_ppl=tf.exp(-tf.reduce_sum(seq_mask*tf.log(tf.reduce_sum(tf.expand_dims(1-tf.nn.sigmoid(indicator_logits),-1)*self.phi*(1-stop_indicator)*tf.reduce_sum(target_to_onehot*((1-params["mixture_lambda"])*h_prob+params["mixture_lambda"]*b_prob),-1),-1)+k_temp+1e-10))/(1e-10+tf.to_float(tf.reduce_sum(inputs["length"]))))


    with tf.name_scope("TextGenerate"):
      # Sample next tokens from the same mixture, using the per-position phi.
      k_text_temp=tf.expand_dims(tf.nn.sigmoid(indicator_logits),-1)*tf.squeeze(h_prob,2)
      phi_text_temp=tf.reduce_sum(tf.expand_dims(tf.expand_dims(1-tf.nn.sigmoid(indicator_logits),-1)*self.phi*(1-stop_indicator),-1)*((1-params["mixture_lambda"])*h_prob+params["mixture_lambda"]*b_prob),2)
      # pred_next_token=tf.argmax(k_text_temp+phi_text_temp,-1)
      pred_next_token=dist.Categorical(probs=k_text_temp+phi_text_temp).sample()


    with tf.name_scope("TextGenerateTheta"):
      # Same as above, but using the document-level theta_point instead of phi.
      # theta_text_temp=tf.reduce_sum(tf.expand_dims(tf.expand_dims(1-tf.nn.sigmoid(indicator_logits),-1)*tf.expand_dims(self.theta,1),-1)*((1-params["mixture_lambda"])*h_prob+params["mixture_lambda"]*b_prob),2)
      theta_text_temp=tf.reduce_sum(tf.expand_dims(tf.expand_dims(1-tf.nn.sigmoid(indicator_logits),-1)*tf.expand_dims(self.theta_point,1),-1)*((1-params["mixture_lambda"])*h_prob+params["mixture_lambda"]*b_prob),2)

      # pred_next_token_theta=tf.argmax(k_text_temp+theta_text_temp,-1)
      pred_next_token_theta=dist.Categorical(probs=k_text_temp+theta_text_temp).sample()
      # Debug output of the static shape at graph-construction time.
      print('-'*50)
      print('pred_next_token_theta',pred_next_token_theta.get_shape())
      print('-'*50)

      # print('pred_next_token',pred_next_token.get_shape())
      # print('pred_next_token',pred_next_token.get_shape())
        # if inputs["model"]=="Valid":
        # all_next_probs=tf.reduce_sum(tf.expand_dims(1-tf.nn.sigmoid(indicator_logits),-1)*self.phi*(1-stop_indicator)*((1-params["mixture_lambda"])*h_prob+params["mixture_lambda"]*b_prob)+k_temp

        # k_temp=tf.nn.sigmoid(indicator_logits)
        # k_temp=tf.reduce_sum(target_to_onehot*h_prob,-1)


        # unif_temp=tf.reduce_sum(tf.expand_dims(tf.nn.sigmoid(indicator_logits),-1)
        # labels_temp=tf.reduce_sum(labels*tf.nn.softmax(indicator_logits),-1)
        # token_ppl=tf.exp(-tf.reduce_sum(seq_mask*tf.log(phi_temp*labels_temp+1e-10))/(1e-5+tf.to_float(tf.reduce_sum(inputs["length"]))))

          # ,-1)
        # +1e-10)
        # print('phi_temp',phi_temp.get_shape())
        # print('k_temp',k_temp.get_shape())

        # print('labels_temp',labels_temp.get_shape())
        # print('seq_mask',seq_mask.get_shape())






    ''' KL between Phi and theta '''
    with tf.name_scope("Phi_theta_kl"):
      # sum_k phi * log(phi / theta) over non-stop positions; summed over
      # time and averaged over the batch.
      theta=tf.expand_dims(self.theta,1)
      phi_theta_kl_loss=tf.reduce_mean(tf.reduce_sum(tf.squeeze(1-stop_indicator,-1)*tf.reduce_sum((1-stop_indicator)*self.phi*tf.log((((1-stop_indicator)*self.phi)/(theta+1e-10))+1e-10),-1),-1))

    total_loss=token_loss+theta_kl_loss+indicator_loss+phi_theta_kl_loss

    with tf.name_scope("SwitchP"):
      # Most likely topic per position.
      all_topics=tf.argmax(self.phi,-1)

    with tf.name_scope("Entropies"):
      # phi_entropy=tf.reduce_mean(tf.reduce_sum(tf.to_float(1-inputs["indicators"])*tf.reduce_sum(-self.phi*tf.log(self.phi+1e-10),-1),-1)/tf.reduce_sum(tf.to_float(1-inputs["indicators"])),-1)
      theta_entropy=tf.reduce_mean(tf.reduce_sum(-self.theta*tf.log(self.theta+1e-10),-1))
      # NOTE(review): phi_entropy is computed but not returned (its output key
      # is commented out below).
      phi_entropy=tf.reduce_mean(tf.reduce_sum(tf.to_float(1-inputs["indicators"])*tf.reduce_sum(-self.phi*tf.log(self.phi+1e-10),-1),-1)/tf.reduce_sum(tf.to_float(1-inputs["indicators"])),-1)
      # print('-'*100)
      # print('theta_entropy',theta_entropy.get_shape())
      # print('-'*100)

      # all_topics=dist.Categorical(probs=self.phi).sample()
      # cat_topic=dist.Categorical(probs=self.theta)

      # all_topics=tf.transpose(cat_topic.sample(sample_shape=[self.phi.get_shape()[1]]))
      # print('all_topics',all_topics.get_shape())
      # print('-'*100)

      # all_topics=tf.self.phi




    tf.summary.scalar(tensor=token_loss, name=mode+" token_loss")
    tf.summary.scalar(tensor=phi_theta_kl_loss, name=mode+" phi_theta_kl_loss")
    tf.summary.scalar(tensor=indicator_loss, name=mode+" indicator_loss")
    tf.summary.scalar(tensor=theta_kl_loss, name=mode+" theta_kl_loss")
    tf.summary.scalar(tensor=total_loss, name=mode+" total_loss")
    tf.summary.scalar(tensor=token_ppl, name=mode+" token_ppl")

    # NOTE(review): with GRUCell layers, final_output[-1] is a single state
    # tensor (presumably [batch, num_units]), so [1] selects batch row 1. That
    # indexing looks carried over from an LSTM (c, h) tuple; confirm "repre".
    outputs = {
        "token_loss": token_loss,
        "token_ppl": token_ppl,
        "indicator_loss": indicator_loss,
        "theta_kl_loss": theta_kl_loss,
        "phi_theta_kl_loss": phi_theta_kl_loss,
        "loss": total_loss,
        "theta": self.theta,
        "repre": final_output[-1][1],
        "beta":self.beta,
        "all_topics": all_topics,
        "non_stop_indic":1-inputs["indicators"],
        "phi":self.phi,
        "pred_next_token":pred_next_token,
        "accuracy":indicator_acc,
        "pred_next_token_theta":pred_next_token_theta,
        "theta_entropy":theta_entropy,
        # "phi_entropy":phi_entropy
        }
    return outputs