def _log_prob(self, data, num_samples=1):
        """Assumes data is [batch_size] + data_dim."""
        batch_size = tf.shape(data)[0]
        log_Z = tf.log_sigmoid(self.logit_Z)  # pylint: disable=invalid-name
        log_1mZ = -self.logit_Z + log_Z  # pylint: disable=invalid-name
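        # log(1 - sigmoid(logit_Z)) = -logit_Z + log_sigmoid(logit_Z), so log_1mZ = log(1 - Z).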

        # [B]
        data_log_accept = tf.squeeze(tf.log_sigmoid(
            self.logit_accept_fn(data)),
                                     axis=-1)
        truncated_geometric_log_probs = tf.range(self.T - 1,
                                                 dtype=self.dtype) * log_1mZ
        # [B, T-1]
        truncated_geometric_log_probs = (
            truncated_geometric_log_probs[None, :] + data_log_accept[:, None])
        # [B, T]
        truncated_geometric_log_probs = tf.concat([
            truncated_geometric_log_probs,
            tf.tile((self.T - 1) * log_1mZ[None, None], [batch_size, 1])
        ],
                                                  axis=-1)
        truncated_geometric_log_probs -= tf.reduce_logsumexp(
            truncated_geometric_log_probs, axis=-1, keepdims=True)

        # [B]
        entropy = -tf.reduce_sum(tf.exp(truncated_geometric_log_probs) *
                                 truncated_geometric_log_probs,
                                 axis=-1)

        proposal_samples = self.proposal.sample([self.T])  # [T] + data_dim
        proposal_logit_accept = self.logit_accept_fn(proposal_samples)
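        # proposal_log_reject: mean over the T proposal samples of log(1 - sigmoid(logit)),
        # i.e. an estimate of the expected log rejection probability under the proposal.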
        proposal_log_reject = tf.reduce_mean(
            -proposal_logit_accept + tf.log_sigmoid(proposal_logit_accept))

        # [B]
        noise_term = tf.reduce_sum(
            tf.exp(truncated_geometric_log_probs) *
            tf.range(self.T, dtype=self.dtype)[None, :] * proposal_log_reject,
            axis=-1)

        try:
            # Pass num_samples to the proposal's log_prob (a lower bound) if it accepts it.
            log_prob_proposal = self.proposal.log_prob(data,
                                                       num_samples=num_samples)
        except TypeError:
            log_prob_proposal = self.proposal.log_prob(data)
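        # Variational lower bound (ELBO) on log p(data): proposal log-prob + log acceptance
        # probability of the data + expected log-rejection mass of the earlier trials
        # + entropy of the truncated geometric over the trial index.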
        elbo = log_prob_proposal + data_log_accept + noise_term + entropy

        return elbo
Example #2
        def generation(h):
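            # Nested helper: args, decoder_inputs_embedded, seq_length, decoder_targets,
            # decoder_mask, att, stack_mu_c, stack_mu_z, stack_log_sigma_sq_c,
            # stack_log_sigma_sq_z, and log_sigma_sq_z are free variables taken from the
            # enclosing scope (not defined in this snippet).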
            with tf.variable_scope("generation", reuse=tf.AUTO_REUSE):
                with tf.variable_scope("decoder"):
                    decoder_init_state = h
                    decoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)
                    # decoder_outputs.shape=(128, None, 256)
                    decoder_outputs, _ = tf.nn.dynamic_rnn(
                        decoder_cell,
                        decoder_inputs_embedded,
                        initial_state=decoder_init_state,
                        sequence_length=seq_length,
                        dtype=tf.float32,
                    )
                with tf.variable_scope("outputs"):
                    out_w = tf.get_variable(
                        "out_w", [out_size, args.rnn_size], tf.float32,
                        tf.random_normal_initializer(stddev=0.02))
                    out_b = tf.get_variable(
                        "out_b", [out_size],
                        tf.float32,
                        initializer=tf.constant_initializer(0.0))

                    batch_rec_loss = tf.reduce_mean(
                        decoder_mask * tf.reshape(
                            tf.nn.sampled_softmax_loss(
                                weights=out_w,
                                biases=out_b,
                                labels=tf.reshape(decoder_targets,
                                                  [-1, 1]),  # shape=(None, 1)
                                inputs=tf.reshape(
                                    decoder_outputs,
                                    [-1, args.rnn_size]),  # shape=(None, 256)
                                num_sampled=args.neg_size,
                                num_classes=out_size),
                            [args.batch_size, -1]),
                        axis=-1)
                    target_out_w = tf.nn.embedding_lookup(
                        out_w, decoder_targets)  # shape=(128, None, 256)
                    target_out_b = tf.nn.embedding_lookup(
                        out_b, decoder_targets)  # shape=(128, None)
                    batch_likelihood = tf.reduce_mean(
                        decoder_mask * tf.log_sigmoid(
                            tf.reduce_sum(decoder_outputs * target_out_w, -1) +
                            target_out_b),
                        axis=-1,
                        name="batch_likelihood")

                    batch_latent_loss = 0.5 * tf.reduce_sum(
                        att * tf.reduce_mean(
                            stack_log_sigma_sq_c + tf.exp(stack_log_sigma_sq_z)
                            / tf.exp(stack_log_sigma_sq_c) +
                            tf.square(stack_mu_z - stack_mu_c) /
                            tf.exp(stack_log_sigma_sq_c),
                            axis=-1),
                        axis=-1) - 0.5 * tf.reduce_mean(1 + log_sigma_sq_z,
                                                        axis=-1)
                    batch_cate_loss = tf.reduce_mean(
                        tf.reduce_mean(att, axis=0) *
                        tf.log(tf.reduce_mean(att, axis=0)))
                return batch_rec_loss, batch_latent_loss, batch_cate_loss, batch_likelihood
Example #3
def DenseShiftLogScale(x,
                       output_units,
                       h=None,
                       hidden_layers=[],
                       activation=tf.nn.relu,
                       log_scale_clip=None,
                       train=False,
                       dropout_rate=0.0,
                       sigmoid_scale=False,
                       log_scale_factor=1.0,
                       log_scale_reg=0.0,
                       *args,
                       **kwargs):
  for units in hidden_layers:
    x = tf.layers.dense(
        inputs=x, units=units, activation=activation, *args, **kwargs)
    if h is not None:
      x += tf.layers.dense(h, units, use_bias=False, *args, **kwargs)
    if dropout_rate > 0:
      x = tf.layers.dropout(x, dropout_rate, training=train)
  if log_scale_factor == 1.0 and log_scale_reg == 0.0:
    x = tf.layers.dense(inputs=x, units=2 * output_units, *args, **kwargs)
    if h is not None:
      x += tf.layers.dense(h, 2 * output_units, use_bias=False, *args, **kwargs)

    shift, log_scale = tf.split(x, 2, axis=-1)
  else:
    shift = tf.layers.dense(x, output_units, *args, **kwargs)
    if log_scale_reg > 0.0:
      regularizer = lambda w: log_scale_reg * 2.0 * tf.nn.l2_loss(w)
    else:
      regularizer = None
    log_scale = tf.layers.dense(
        x,
        output_units,
        use_bias=False,
        kernel_regularizer=regularizer,
        *args,
        **kwargs)
    log_scale = log_scale * log_scale_factor + tf.get_variable(
        "log_scale_bias", [1, output_units], initializer=tf.zeros_initializer())
    if h is not None:
      shift += tf.layers.dense(h, output_units, use_bias=False, *args, **kwargs)
      log_scale += tf.layers.dense(
          h, output_units, use_bias=False, *args, **kwargs)

  if sigmoid_scale:
    log_scale = tf.log_sigmoid(log_scale)

  if log_scale_clip:
    log_scale = log_scale_clip * tf.nn.tanh(log_scale / log_scale_clip)

  return shift, log_scale
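For reference, a minimal sketch of calling DenseShiftLogScale (assumes TensorFlow 1.x; the shapes and layer sizes are illustrative, not taken from the original project):

import tensorflow as tf  # TF 1.x

x = tf.random_normal([16, 4])  # a batch of 16 four-dimensional inputs
shift, log_scale = DenseShiftLogScale(x,
                                      output_units=4,
                                      hidden_layers=[32, 32],
                                      sigmoid_scale=True)
# shift and log_scale each have shape [16, 4] and can parameterize an
# elementwise affine transform such as y = x * tf.exp(log_scale) + shift.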
Example #4
    def _build_model(self):
        self.graph_built = True
        tf.set_random_seed(self.seed)
        self.labels = tf.placeholder(tf.float32, shape=[None])
        self.is_training = tf.placeholder_with_default(False, shape=[])
        self._build_variables()
        self._build_user_embeddings()
        if self.task == "rating" or self.loss_type == "cross_entropy":
            self.user_indices = tf.placeholder(tf.int32, shape=[None])
            self.item_indices = tf.placeholder(tf.int32, shape=[None])

            item_embed = tf.nn.embedding_lookup(
                self.item_weights, self.item_indices
            )
            item_bias = tf.nn.embedding_lookup(
                self.item_biases, self.item_indices
            )
            self.output = tf.reduce_sum(
                tf.multiply(self.user_embed, item_embed), axis=1
            ) + item_bias

        elif self.loss_type == "bpr":
            self.item_indices_pos = tf.placeholder(tf.int32, shape=[None])
            self.item_indices_neg = tf.placeholder(tf.int32, shape=[None])
            item_embed_pos = tf.nn.embedding_lookup(
                self.item_weights, self.item_indices_pos
            )
            item_embed_neg = tf.nn.embedding_lookup(
                self.item_weights, self.item_indices_neg
            )
            item_bias_pos = tf.nn.embedding_lookup(
                self.item_biases, self.item_indices_pos
            )
            item_bias_neg = tf.nn.embedding_lookup(
                self.item_biases, self.item_indices_neg
            )

            item_diff = tf.subtract(item_bias_pos,
                                    item_bias_neg) + tf.reduce_sum(
                tf.multiply(
                    self.user_embed,
                    tf.subtract(item_embed_pos, item_embed_neg)
                ), axis=1
            )
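            # item_diff scores the positive item against the negative one; its
            # log-sigmoid is the per-pair BPR log-likelihood, which a BPR loss
            # would negate and average (presumably done elsewhere in the class).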
            self.log_sigmoid = tf.log_sigmoid(item_diff)

        count_params()
Example #5
    def _build_model_tf(self):
        self.graph_built = True
        if isinstance(self.reg, float) and self.reg > 0.0:
            tf_reg = tf.keras.regularizers.l2(self.reg)
        else:
            tf_reg = None

        self.user_indices = tf.placeholder(tf.int32, shape=[None])
        self.item_indices_pos = tf.placeholder(tf.int32, shape=[None])
        self.item_indices_neg = tf.placeholder(tf.int32, shape=[None])

        self.item_bias_var = tf.get_variable(name="item_bias_var",
                                             shape=[self.n_items + 1],
                                             initializer=tf_zeros,
                                             regularizer=tf_reg)
        self.user_embed_var = tf.get_variable(
            name="user_embed_var",
            shape=[self.n_users + 1, self.embed_size],
            initializer=tf_truncated_normal(0.0, 0.03),
            regularizer=tf_reg)
        self.item_embed_var = tf.get_variable(
            name="item_embed_var",
            shape=[self.n_items + 1, self.embed_size],
            initializer=tf_truncated_normal(0.0, 0.03),
            regularizer=tf_reg)

        bias_item_pos = tf.nn.embedding_lookup(self.item_bias_var,
                                               self.item_indices_pos)
        bias_item_neg = tf.nn.embedding_lookup(self.item_bias_var,
                                               self.item_indices_neg)
        embed_user = tf.nn.embedding_lookup(self.user_embed_var,
                                            self.user_indices)
        embed_item_pos = tf.nn.embedding_lookup(self.item_embed_var,
                                                self.item_indices_pos)
        embed_item_neg = tf.nn.embedding_lookup(self.item_embed_var,
                                                self.item_indices_neg)

        item_diff = tf.subtract(
            bias_item_pos, bias_item_neg) + tf.reduce_sum(tf.multiply(
                embed_user, tf.subtract(embed_item_pos, embed_item_neg)),
                                                          axis=1)
        self.log_sigmoid = tf.log_sigmoid(item_diff)
Example #6
    def __init__(
            self,  # pylint: disable=invalid-name
            T,
            data_dim,
            logit_accept_fn,
            proposal=None,
            dtype=tf.float32,
            name="rejection_sampling"):
        """Creates a Rejection Sampling model.

    Args:
      T: The maximum number of proposals to sample in the rejection sampler.
      data_dim: The dimension of the data. Should be a list.
      logit_accept_fn: Accept function, takes [batch_size] + data_dim to [0, 1].
      proposal: A distribution over the data space of this model. Must support
        sample() and log_prob() although log_prob only needs to return a lower
        bound on the true log probability. If not supplied, then defaults to
        Gaussian.
      dtype: Type of data.
      name: Name to use in scopes.
    """
        self.T = T  # pylint: disable=invalid-name
        self.data_dim = data_dim
        self.logit_accept_fn = logit_accept_fn
        if proposal is None:
            self.proposal = base.get_independent_normal(data_dim)
        else:
            self.proposal = proposal
        self.name = name
        self.dtype = dtype
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            self.logit_Z = tf.get_variable(  # pylint: disable=invalid-name
                name="logit_Z",
                shape=[],
                dtype=dtype,
                initializer=tf.constant_initializer(0.),
                trainable=True)

        tf.summary.scalar("Expected_trials",
                          tf.exp(-tf.log_sigmoid(self.logit_Z)))
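For orientation, a minimal sketch of constructing this model and evaluating its bound. This is hypothetical: it assumes TensorFlow 1.x, that the constructor belongs to a class named RejectionSampling (as in Example #10), and uses a toy dense network as logit_accept_fn.

import tensorflow as tf  # TF 1.x

def toy_logit_accept_fn(x):
    # Toy acceptance network: flatten, one hidden layer, one scalar logit per example.
    h = tf.layers.dense(tf.layers.flatten(x), 64, activation=tf.nn.relu)
    return tf.layers.dense(h, 1)  # shape [batch_size, 1]

model = RejectionSampling(T=10, data_dim=[2], logit_accept_fn=toy_logit_accept_fn)
data = tf.random_normal([32, 2])
elbo = model._log_prob(data)  # lower bound on log p(data), shape [32]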
Example #7
def kl_loss(src_emb, pos_emb, neg_emb, sim_fn="dot"):
  """kl loss used for line.
  Args:
    src_emb: tensor with shape [batch_size, dim]
    pos_emb: tensor with shape [batch_size, dim]
    neg_emb: tensor with shape [batch_size * neg_num, dim]
    sim_fn: similarity measure function, cosine, euclidean and dot.
  Returns:
    loss, logit, label
  """

  if sim_fn == "cosine":
    sim_function = _rank_cosine_distance
  elif sim_fn == "euclidean":
    sim_function = _rank_euclidean_distance
  elif sim_fn == "dot":
    sim_function = _rank_dot_product
  else:
    raise ValueError("unsupported similarity measure function: %s" % sim_fn)

  emb_dim = tf.shape(src_emb)[1]
  batch_size = tf.shape(src_emb)[0]
  # Use integer division so the result can be used as a tf.tile multiple.
  per_sample_neg_num = tf.shape(neg_emb)[0] // batch_size
  pos_inner_product = sim_function(src_emb, pos_emb)

  src_emb_exp = tf.tile(tf.expand_dims(src_emb, axis=1),
                        [1, per_sample_neg_num, 1])
  src_emb_exp = tf.reshape(src_emb_exp, [-1, emb_dim])
  neg_inner_product = tf.reduce_sum(tf.multiply(src_emb_exp, neg_emb), axis=-1)

  logits = tf.concat([pos_inner_product, neg_inner_product], axis=0)
  labels = tf.concat([tf.ones_like(pos_inner_product),
                      -1 * tf.ones_like(neg_inner_product)], axis=0)

  loss = -tf.reduce_mean(tf.log_sigmoid(logits * labels))

  return [loss, logits, labels]
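A minimal calling sketch (the _rank_* similarity helpers are assumed to be defined elsewhere in the same module; shapes follow the docstring, with 5 negatives per source node as an arbitrary choice):

import tensorflow as tf  # TF 1.x

src_emb = tf.random_normal([32, 128])
pos_emb = tf.random_normal([32, 128])
neg_emb = tf.random_normal([32 * 5, 128])  # 5 negative samples per source
loss, logits, labels = kl_loss(src_emb, pos_emb, neg_emb, sim_fn="dot")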
Example #8
    def __init__(self, args):
        self.args = args
        dense = tf.layers.dense

        # inputs/mask.shape=(128, None); 'None' means any sequence length. seq_length.shape=(128,)
        inputs = tf.placeholder(shape=(args.batch_size, None),
                                dtype=tf.int32,
                                name='inputs')
        time_inputs = tf.placeholder(shape=(args.batch_size, None),
                                     dtype=tf.int32,
                                     name='time_inputs')
        mask = tf.placeholder(shape=(args.batch_size, None),
                              dtype=tf.float32,
                              name='inputs_mask')
        seq_length = tf.placeholder(shape=args.batch_size,
                                    dtype=tf.float32,
                                    name='seq_length')

        self.input_form = [inputs, time_inputs, mask, seq_length]

        # all shape=(128, None)
        encoder_inputs = inputs
        decoder_inputs = tf.concat(
            [tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32), inputs],
            axis=1)
        decoder_targets = tf.concat(
            [inputs,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.int32)],
            axis=1)
        decoder_mask = tf.concat(
            [mask,
             tf.zeros(shape=(args.batch_size, 1), dtype=tf.float32)],
            axis=1)

        x_size = out_size = args.map_size[0] * args.map_size[1]
        # embeddings.shape=(16900, 32)  tf.random_uniform(shape, minval=0, maxval=None, ...)
        # x_latent_size is the input embedding size = 32
        embeddings = tf.Variable(tf.random_uniform(
            [x_size, args.x_latent_size], -1.0, 1.0),
                                 dtype=tf.float32)
        # tf.nn.embedding_lookup(params, ids, ...)  Looks up ids in a list of embedding tensors.
        # shape=(128, None, 32)
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, encoder_inputs)
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, decoder_inputs)

        time_embeddings = tf.Variable(tf.random_uniform(
            [49, args.x_latent_size], -1.0, 1.0),
                                      dtype=tf.float32)
        encoder_time_inputs_embedded = tf.nn.embedding_lookup(
            time_embeddings, time_inputs)

        time_mean = tf.reduce_mean(encoder_time_inputs_embedded, axis=1)
        mu_delta = dense(time_mean, args.rnn_size, activation=None)
        log_sigma_sq_delta = dense(time_mean, args.rnn_size, activation=None)

        with tf.variable_scope("encoder"):
            # create a GRUCell  output_size = state_size = 256
            encoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)

            # tf.compat.v1.nn.dynamic_rnn(cell, inputs, ...) = keras.layers.RNN(cell)
            # returns (outputs, state)
            # 'outputs' is a tensor of shape [batch_size, max_time, cell_output_size]
            # 'state' is a tensor of shape [batch_size, cell_state_size] = (128, 256)
            _, encoder_final_state = tf.nn.dynamic_rnn(
                encoder_cell,
                encoder_inputs_embedded,
                sequence_length=seq_length,
                dtype=tf.float32,
            )

        # tf.compat.v1.get_variable(name, shape=None, dtype=None,
        #                           initializer=None, ...)
        mu_w = tf.get_variable("mu_w", [args.rnn_size, args.rnn_size],
                               tf.float32,
                               tf.random_normal_initializer(stddev=0.02))
        mu_b = tf.get_variable("mu_b", [args.rnn_size], tf.float32,
                               tf.constant_initializer(0.0))
        sigma_w = tf.get_variable("sigma_w", [args.rnn_size, args.rnn_size],
                                  tf.float32,
                                  tf.random_normal_initializer(stddev=0.02))
        sigma_b = tf.get_variable("sigma_b", [args.rnn_size], tf.float32,
                                  tf.constant_initializer(0.0))

        # all shape=(128, 256)
        mu = tf.matmul(encoder_final_state, mu_w) + mu_b + mu_delta
        log_sigma_sq = tf.matmul(encoder_final_state,
                                 sigma_w) + sigma_b + log_sigma_sq_delta
        eps = tf.random_normal(shape=tf.shape(log_sigma_sq),
                               mean=0,
                               stddev=1,
                               dtype=tf.float32)

        if args.eval:
            # z = tf.zeros(shape=(args.batch_size, args.rnn_size), dtype=tf.float32)
            z = mu_delta
        else:
            # Re-parameterization trick
            z = mu + tf.sqrt(tf.exp(log_sigma_sq)) * eps

        self.batch_post_embedded = z

        with tf.variable_scope("decoder"):
            decoder_cell = tf.nn.rnn_cell.GRUCell(args.rnn_size)
            decoder_init_state = z
            decoder_outputs, _ = tf.nn.dynamic_rnn(
                decoder_cell,
                decoder_inputs_embedded,
                initial_state=decoder_init_state,
                sequence_length=seq_length,
                dtype=tf.float32,
            )

        # out_size = 16900
        out_w = tf.get_variable("out_w", [out_size, args.rnn_size], tf.float32,
                                tf.random_normal_initializer(stddev=0.02))
        out_b = tf.get_variable("out_b", [out_size], tf.float32,
                                tf.constant_initializer(0.0))
        # tf.reduce_mean(input_tensor, axis=None, ...)  Reduces input_tensor to its mean along the given axis.
        # tf.reshape(tensor, shape, name=None)  Reshapes the tensor; -1 means that dimension is inferred.
        # tf.nn.sampled_softmax_loss()  Fast approximation for training a large softmax classifier; it underestimates the full loss, so use it for training only.
        batch_rec_loss = tf.reduce_mean(
            decoder_mask * tf.reshape(
                tf.nn.sampled_softmax_loss(
                    weights=out_w,
                    biases=out_b,
                    labels=tf.reshape(decoder_targets, [-1, 1]),
                    inputs=tf.reshape(decoder_outputs, [-1, args.rnn_size]),
                    num_sampled=args.neg_size,
                    num_classes=out_size), [args.batch_size, -1]),
            axis=-1  # reduce to mean along the last dimension
        )
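        # Standard VAE KL term: KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).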
        batch_latent_loss = -0.5 * tf.reduce_sum(
            1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

        self.rec_loss = rec_loss = tf.reduce_mean(batch_rec_loss)
        self.latent_loss = latent_loss = tf.reduce_mean(batch_latent_loss)

        self.loss = loss = tf.reduce_mean([rec_loss, latent_loss])
        self.train_op = tf.train.AdamOptimizer(
            args.learning_rate).minimize(loss)

        target_out_w = tf.nn.embedding_lookup(out_w, decoder_targets)
        target_out_b = tf.nn.embedding_lookup(out_b, decoder_targets)

        self.batch_likelihood = tf.reduce_mean(decoder_mask * tf.log_sigmoid(
            tf.reduce_sum(decoder_outputs * target_out_w, -1) + target_out_b),
                                               axis=-1,
                                               name="batch_likelihood")

        # save/restore variables to/from checkpoints, max_to_keep = max #recent checkpoint files to keep.
        saver = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
                               max_to_keep=10)
        self.save, self.restore = saver.save, saver.restore
Example #9
def DenseAR(x,
            h=None,
            hidden_layers=[],
            activation=tf.nn.relu,
            log_scale_clip=None,
            log_scale_clip_pre=None,
            train=False,
            dropout_rate=0.0,
            sigmoid_scale=False,
            log_scale_factor=1.0,
            log_scale_reg=0.0,
            shift_only=False,
            *args,
            **kwargs):
    input_depth = x.shape.with_rank_at_least(1)[-1].value
    if input_depth is None:
        raise NotImplementedError(
            "Rightmost dimension must be known prior to graph execution.")
    input_shape = (np.int32(x.shape.as_list())
                   if x.shape.is_fully_defined() else tf.shape(x))
    for i, units in enumerate(hidden_layers):
        x = tfb.masked_dense(inputs=x,
                             units=units,
                             num_blocks=input_depth,
                             exclusive=True if i == 0 else False,
                             activation=activation,
                             *args,
                             **kwargs)
        if h is not None:
            x += tf.layers.dense(h, units, use_bias=False, *args, **kwargs)
        if dropout_rate > 0:
            x = tf.layers.dropout(x, dropout_rate, training=train)

    if shift_only:
        shift = tfb.masked_dense(inputs=x,
                                 units=input_depth,
                                 num_blocks=input_depth,
                                 activation=None,
                                 *args,
                                 **kwargs)
        return shift, None
    else:
        if log_scale_factor == 1.0 and log_scale_reg == 0.0 and not log_scale_clip_pre:
            x = tfb.masked_dense(inputs=x,
                                 units=2 * input_depth,
                                 num_blocks=input_depth,
                                 activation=None,
                                 *args,
                                 **kwargs)
            if h is not None:
                x += tf.layers.dense(h,
                                     2 * input_depth,
                                     use_bias=False,
                                     *args,
                                     **kwargs)
            x = tf.reshape(x, shape=tf.concat([input_shape, [2]], axis=0))
            shift, log_scale = tf.unstack(x, num=2, axis=-1)
        else:
            shift = tfb.masked_dense(inputs=x,
                                     units=input_depth,
                                     num_blocks=input_depth,
                                     activation=None,
                                     *args,
                                     **kwargs)
            if log_scale_reg > 0.0:
                regularizer = lambda w: log_scale_reg * 2.0 * tf.nn.l2_loss(w)
            else:
                regularizer = None
            log_scale = tfb.masked_dense(inputs=x,
                                         units=input_depth,
                                         num_blocks=input_depth,
                                         activation=None,
                                         use_bias=False,
                                         kernel_regularizer=regularizer,
                                         *args,
                                         **kwargs)
            log_scale *= log_scale_factor
            if log_scale_clip_pre:
                log_scale = log_scale_clip_pre * tf.nn.tanh(
                    log_scale / log_scale_clip_pre)
            log_scale += tf.get_variable("log_scale_bias", [1, input_depth],
                                         initializer=tf.zeros_initializer())
            if h is not None:
                shift += tf.layers.dense(h,
                                         input_depth,
                                         use_bias=False,
                                         *args,
                                         **kwargs)
                log_scale += tf.layers.dense(h,
                                             input_depth,
                                             use_bias=False,
                                             *args,
                                             **kwargs)

        if sigmoid_scale:
            log_scale = tf.log_sigmoid(log_scale)

        if log_scale_clip:
            log_scale = log_scale_clip * tf.nn.tanh(log_scale / log_scale_clip)

        return shift, log_scale
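A minimal calling sketch for DenseAR (assumes TensorFlow 1.x and that `tfb` above is a bijectors module providing masked_dense, e.g. tfp.bijectors in TF1-era code; shapes and layer sizes are illustrative):

import tensorflow as tf  # TF 1.x

x = tf.random_normal([16, 8])  # rightmost dimension must be statically known
shift, log_scale = DenseAR(x, hidden_layers=[64, 64], sigmoid_scale=True)
# Both outputs have shape [16, 8]; masked_dense keeps them autoregressive in the
# last dimension, so they can parameterize a masked autoregressive flow.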
Example #10
def main(unused_argv):
    g = tf.Graph()
    with g.as_default():
        target = dists.get_target_distribution(
            FLAGS.target,
            nine_gaussians_variance=FLAGS.nine_gaussians_variance)
        energy_fn_layers = [
            int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
        ]
        if FLAGS.algo == "lars":
            print("Running LARS")
            loss, train_op, global_step = make_lars_graph(
                target_dist=target,
                K=FLAGS.K,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate,
                mlp_layers=energy_fn_layers,
                dtype=tf.float32)
        else:
            proposal = base.get_independent_normal([2],
                                                   FLAGS.proposal_variance)
            if FLAGS.algo == "nis":
                print("Running NIS")
                model = nis.NIS(K=FLAGS.K,
                                data_dim=[2],
                                energy_hidden_sizes=energy_fn_layers,
                                proposal=proposal)
                density_image_summary(
                    lambda x:  # pylint: disable=g-long-lambda
                    (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(
                        x)),
                    FLAGS.density_num_bins,
                    "energy/nis")
            elif FLAGS.algo == "rejection_sampling":
                print("Running Rejection Sampling")
                model = rejection_sampling.RejectionSampling(
                    T=FLAGS.K,
                    data_dim=[2],
                    energy_hidden_sizes=energy_fn_layers,
                    proposal=proposal)
                density_image_summary(
                    lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                        tf.log_sigmoid(model.logit_accept_fn(x)),
                        axis=-1) + model.proposal.log_prob(x),
                    FLAGS.density_num_bins,
                    "energy/trs")
            elif FLAGS.algo == "his":
                print("Running HIS")
                model = his.FullyConnectedHIS(
                    T=FLAGS.his_t,
                    data_dim=[2],
                    energy_hidden_sizes=energy_fn_layers,
                    q_hidden_sizes=energy_fn_layers,
                    init_step_size=FLAGS.his_stepsize,
                    learn_stepsize=FLAGS.his_learn_stepsize,
                    init_alpha=FLAGS.his_alpha,
                    learn_temps=FLAGS.his_learn_alpha,
                    proposal=proposal)
                density_image_summary(
                    lambda x: -model.hamiltonian_potential(x),
                    FLAGS.density_num_bins, "energy/his")
                sample_image_summary(model,
                                     "density/his",
                                     num_samples=100000,
                                     num_bins=50)

            loss, train_op, global_step = make_train_graph(
                target_dist=target,
                model=model,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate)

        log_hooks = make_log_hooks(global_step, loss)
        with tf.train.MonitoredTrainingSession(
                master="",
                is_chief=True,
                hooks=log_hooks,
                checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
                save_checkpoint_secs=120,
                save_summaries_steps=FLAGS.summarize_every,
                log_step_count_steps=FLAGS.summarize_every) as sess:
            cur_step = -1
            while True:
                if sess.should_stop() or cur_step > FLAGS.max_steps:
                    break
                # run a step
                _, cur_step = sess.run([train_op, global_step])