Example #1
def isemhash_bottleneck(x,
                        bottleneck_size,
                        bottleneck_noise,
                        discretize_warmup_steps,
                        mode,
                        isemhash_noise_dev=0.5,
                        isemhash_mix_prob=0.5):
    """Improved semantic hashing bottleneck."""
    with tf.variable_scope("isemhash_bottleneck"):
        x = tf.layers.dense(x, bottleneck_size, name="dense")
        y = common_layers.saturating_sigmoid(x)
        if isemhash_noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
            noise = tf.truncated_normal(common_layers.shape_list(x),
                                        mean=0.0,
                                        stddev=isemhash_noise_dev)
            y = common_layers.saturating_sigmoid(x + noise)
        d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
        d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
        if mode == tf.estimator.ModeKeys.TRAIN:  # Flip some bits.
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
            d *= noise
            d = common_layers.mix(d,
                                  2.0 * y - 1.0,
                                  discretize_warmup_steps,
                                  mode == tf.estimator.ModeKeys.TRAIN,
                                  max_prob=isemhash_mix_prob)
        return d
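A minimal standalone sketch of the straight-through discretization used above, assuming TF 2.x eager execution and writing the saturating sigmoid inline as clip(1.2*sigmoid(x) - 0.1, 0, 1), which is consistent with the testSaturatingSigmoid values further down; this is not the common_layers implementation itself:

import tensorflow as tf

x = tf.Variable([[-1.0, 0.2, 2.0]])
with tf.GradientTape() as tape:
    y = tf.clip_by_value(1.2 * tf.sigmoid(x) - 0.1, 0.0, 1.0)   # saturating sigmoid, inlined
    d = tf.cast(y > 0.5, tf.float32) + y - tf.stop_gradient(y)  # forward: hard bits; backward: gradient of y
    d = 2.0 * d - 1.0                                           # move from [0, 1] to [-1, 1]
    loss = tf.reduce_sum(d)
grad = tape.gradient(loss, x)  # nonzero wherever the sigmoid is not saturated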
Example #2
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""
  def embed1(x):
    if hparams.bottleneck_kind == "semhash":
      c = int_to_bit(x, c_size)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      return h1a + h1b
    elif hparams.bottleneck_kind == "gumbel-softmax":
      hot = tf.one_hot(x, hparams.v_size)
      with tf.variable_scope(name, reuse=True):
        return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")

  def embed(x):
    with tf.variable_scope(name, reuse=True):
      h1 = embed1(x)
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res

  with tf.variable_scope(name):
    c_size = hparams.c_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, c_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, c_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      # l = tf.reduce_mean(y_clean * (1.0 - y_clean))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.cond(tf.less(tf.random_uniform([]), pd),
                  lambda: y_discrete, lambda: y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, c_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
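Hypothetical NumPy stand-ins for the int_to_bit / bit_to_int helpers the semhash branch relies on (the tensor2tensor versions operate on tensors and may order bits differently); they simply pack a vector of binary indicators into a single integer code and back:

import numpy as np

def bit_to_int(bits, num_bits, base=2):
    return (bits * base ** np.arange(num_bits)).sum(-1)

def int_to_bit(x, num_bits, base=2):
    return (np.asarray(x)[..., None] // base ** np.arange(num_bits)) % base

bits = np.array([1, 0, 1, 1])
code = bit_to_int(bits, num_bits=4)                 # 13 with this (assumed) bit order
assert (int_to_bit(code, num_bits=4) == bits).all()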
Example #3
def bit_vae(x, hparams, name):
    with tf.variable_scope(name):
        bity = tf.layers.dense(x, hparams.z_size, name="bity")
        dev = common_layers.inverse_lin_decay(hparams.startup_steps) * 1.5
        noise = tf.random_normal(tf.shape(bity), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(bity + noise)
        tf.summary.histogram("bit", tf.reshape(y, [-1]))

        def discrete_y():
            d = tf.to_float(tf.less(0.5, y))
            return tf.stop_gradient(d) + y - tf.stop_gradient(y)

        y = tf.cond(tf.less(tf.train.get_global_step(), hparams.startup_steps),
                    lambda: y, discrete_y)
        # Flatten and predict for loss.
        y_flat = tf.reshape(y, [-1, hparams.z_size, 1, 1])
        hsize = hparams.hidden_size
        hparams.hidden_size = hsize // 2
        emb0 = tf.get_variable("emb0", [hparams.hidden_size])
        emb1 = tf.get_variable("emb1", [hparams.hidden_size])
        emb0 = tf.reshape(emb0, [1, 1, 1, hparams.hidden_size])
        emb1 = tf.reshape(emb1, [1, 1, 1, hparams.hidden_size])
        y_emb = y_flat * emb1 + (1 - y_flat) * emb0
        y_logit = decode(None, None, y_emb, None, None, hparams, "dbit")
        hparams.hidden_size = hsize
        y_pred = tf.nn.log_softmax(tf.layers.dense(y_logit, 2, name="y_pred"))
        y_flat = tf.reshape(y_flat, [-1])
        y_pred = tf.reshape(y_pred, [-1, 2])
        loss = -(y_flat * y_pred[:, 1] + (1 - y_flat) * y_pred[:, 0])
        # Get the final z and return.
        z = tf.layers.dense(y, hparams.z_size, name="after_bit")
        return z, tf.reduce_mean(loss)
Example #4
def isemhash_bottleneck(x, bottleneck_bits, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck."""
  with tf.variable_scope("isemhash_bottleneck"):
    x = tf.layers.dense(x, bottleneck_bits, name="dense")
    y = common_layers.saturating_sigmoid(x)
    if isemhash_noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.truncated_normal(
          common_layers.shape_list(x), mean=0.0, stddev=isemhash_noise_dev)
      y = common_layers.saturating_sigmoid(x + noise)
    d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
    d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
    if mode == tf.estimator.ModeKeys.TRAIN:  # Flip some bits.
      noise = tf.random_uniform(common_layers.shape_list(x))
      noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
      d *= noise
      d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
                            mode == tf.estimator.ModeKeys.TRAIN,
                            max_prob=isemhash_mix_prob)
    return d, 0.0
Example #5
 def testSaturatingSigmoid(self):
   x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
   y = common_layers.saturating_sigmoid(tf.constant(x))
   res = self.evaluate(y)
   self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
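For reference, a standalone NumPy version of saturating_sigmoid consistent with the expected values in this test (the actual common_layers implementation may differ in detail): a sigmoid scaled by 1.2, shifted down by 0.1, and clipped to [0, 1], so it reaches exactly 0 and 1 for large |x| and stays 0.5 at x = 0.

import numpy as np

def saturating_sigmoid(x):
    y = 1.0 / (1.0 + np.exp(-np.asarray(x, dtype=np.float64)))
    return np.clip(1.2 * y - 0.1, 0.0, 1.0)

x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
print(saturating_sigmoid(x))  # [0.  0.  0.5 1.  1. ]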
Example #6
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
    """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector
      quantization (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
    tf.logging.info("Shape of x = {}".format(common_layers.shape_list(x)))
    block_v_size = None
    if bottleneck_kind == "dvq":
        # Define the dvq parameters
        assert means is not None

        # Check block dimensions add up
        if hidden_size % num_blocks != 0:
            raise ValueError("num_blocks does not divide hidden size")

        if z_size % num_residuals != 0:
            raise ValueError(
                "num_residuals does not divide embedding table size")

        z_size_per_residual = int(z_size / num_residuals)

        if z_size_per_residual % num_blocks != 0:
            raise ValueError("num_blocks does not divide embedding table size")

        block_v_size = 2**(z_size_per_residual / num_blocks)
        block_v_size = int(block_v_size)

        # Set the reshape method corresponding to projections or slices
        if reshape_method == "slice":
            reshape_fn = partial(slice_hidden,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        elif reshape_method == "project":
            if projection_tensors is None:
                raise ValueError(
                    "Projection tensors is None for reshape_method project")
            reshape_fn = partial(project_hidden,
                                 projection_tensors=projection_tensors,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        else:
            raise ValueError("Unknown reshape_method")

        # Check if the ema settings make sense
        if ema:
            if ema_count is None:
                raise ValueError("ema_count is None but ema is True")
            if ema_means is None:
                raise ValueError("ema_means is None but ema is True")

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        l = tf.constant(0.0)
        if bottleneck_kind == "dense":
            c = tf.layers.dense(x, z_size, name="vcc")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        elif bottleneck_kind == "vae":
            c, l, _, _ = vae(x, z_size, "vae")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        elif bottleneck_kind == "semhash":
            c = tf.layers.dense(x, z_size, name="vcc")
            y_clean = common_layers.saturating_sigmoid(c)
            if summary:
                tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
            if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
                noise = tf.truncated_normal(common_layers.shape_list(c),
                                            mean=0.0,
                                            stddev=noise_dev)
                y = common_layers.saturating_sigmoid(c + noise)
            else:
                y = y_clean
            d = tf.to_float(tf.less(0.5, y))
            y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
            pd = common_layers.inverse_exp_decay(startup_steps * 2)
            pd *= discrete_mix
            pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
            c = tf.where(
                tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]),
                        pd), y_discrete, y)
            h1a = tf.layers.dense(c, filter_size, name="vch1a")
            h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
            h1 = h1a + h1b
            dx = tf.to_int32(tf.stop_gradient(d))
            c = bit_to_int(dx, z_size)
        elif bottleneck_kind == "gumbel-softmax":
            _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                       kl_warmup_steps, summary)
            c = tf.argmax(hot, axis=-1)
            h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
        elif bottleneck_kind == "dvq":
            x_reshaped = reshape_fn(x)
            x_res = x_reshaped
            x_means_hot = []
            x_means = 0
            l = 0
            for i in range(num_residuals):
                x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
                    x_res, means[i], num_blocks, block_v_size, random_top_k,
                    soft_em, num_samples)

                # Update the ema variables
                if ema:
                    tf.logging.info("Using EMA with beta = {}".format(beta))
                    updated_ema_count_res = moving_averages.assign_moving_average(
                        ema_count[i],
                        tf.reduce_sum(tf.reshape(
                            x_means_hot_res,
                            shape=[-1, num_blocks, block_v_size]),
                                      axis=0),
                        decay,
                        zero_debias=False)

                    dw = tf.matmul(
                        tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
                        tf.transpose(x_res, perm=[1, 0, 2]))

                    updated_ema_means_res = moving_averages.assign_moving_average(
                        ema_means[i], dw, decay, zero_debias=False)
                    n = tf.reduce_sum(updated_ema_count_res,
                                      axis=-1,
                                      keep_dims=True)
                    updated_ema_count_res = (
                        (updated_ema_count_res + epsilon) /
                        (n + 2**z_size * epsilon) * n)
                    updated_ema_means_res /= tf.expand_dims(
                        updated_ema_count_res, axis=-1)

                    with tf.control_dependencies([e_loss_res]):
                        update_means_res = tf.assign(means[i],
                                                     updated_ema_means_res)
                        with tf.control_dependencies([update_means_res]):
                            l += beta * e_loss_res
                else:
                    l += q_loss_res + beta * e_loss_res

                # Update the residuals
                x_res -= x_means_res
                x_means += x_means_res
                x_means_hot.append(x_means_hot_res)

            # Get the discrete latent representation
            x_means_hot = tf.stack(x_means_hot, axis=1)
            x_means_idx = tf.argmax(x_means_hot, axis=-1)

            # Get the binary representation
            x_means_bits = int_to_bit(
                x_means_idx,
                num_bits=int(z_size / (num_residuals * num_blocks)),
                base=2)
            shape = common_layers.shape_list(x_means_bits)
            new_shape = shape[:-2]
            new_shape[-1] = z_size
            x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
            c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

            # Adjust shape of c
            shape_x = common_layers.shape_list(x)
            new_shape = shape_x[:-1]
            c = tf.reshape(c, new_shape)

            x_means = tf.reshape(x_means, shape_x)
            x_reshaped = tf.reshape(x_reshaped, shape_x)
            h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
        else:
            raise ValueError("Unknown discretization method.")

        h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
        res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name="vcfin")

        embed_fn = partial(embed,
                           hidden_size=hidden_size,
                           z_size=z_size,
                           filter_size=filter_size,
                           name=name,
                           bottleneck_kind=bottleneck_kind,
                           num_blocks=num_blocks,
                           num_residuals=num_residuals,
                           block_v_size=block_v_size,
                           means=means)
        return res, c, l, embed_fn
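A NumPy sketch of the EMA codebook update performed in the dvq branch above, written for a single block (names and shapes are illustrative, not the embedding_lookup API): cluster counts and per-code sums are tracked as moving averages, the counts are Laplace-smoothed, and the codebook is set to the smoothed means.

import numpy as np

decay, epsilon, V, D = 0.999, 1e-5, 8, 4
codes = np.random.randn(V, D)                 # codebook ("means") for one block
ema_count = np.ones(V)
ema_means = codes.copy()

x = np.random.randn(32, D)                    # a batch of latents for this block
nearest = np.argmin(((x[:, None, :] - codes[None, :, :]) ** 2).sum(-1), axis=1)
onehot = np.eye(V)[nearest]                   # x_means_hot for this block

# assign_moving_average(v, value, decay): v <- decay * v + (1 - decay) * value
ema_count = decay * ema_count + (1 - decay) * onehot.sum(axis=0)
ema_means = decay * ema_means + (1 - decay) * onehot.T @ x

n = ema_count.sum()
smoothed = (ema_count + epsilon) / (n + V * epsilon) * n   # Laplace-smooth the counts
codes = ema_means / smoothed[:, None]                      # updated codebook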
Example #7
 def testSaturatingSigmoid(self):
     x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
     with self.test_session() as session:
         y = common_layers.saturating_sigmoid(tf.constant(x))
         res = session.run(y)
     self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
Example #8
 def testSaturatingSigmoid(self):
     x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
     y = common_layers.saturating_sigmoid(tf.constant(x))
     res = self.evaluate(y)
     self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
Example #9
def bottleneck(x,
               hparams,
               filter_size,
               name,
               means=None,
               ema_count=None,
               ema_means=None):
  """Bottleneck."""
  if hparams.bottleneck_kind == "vq-vae":
    assert means is not None
    if hparams.ema:
      assert ema_count is not None
      assert ema_means is not None

  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        if hparams.ema:
          means_embed = ema_means
        else:
          means_embed = means

        h1 = tf.gather(means_embed, x)
      elif hparams.bottleneck_kind == "rounding":
        h1 = x

      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(common_layers.shape_list(c),
                                    mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(tf.less(tf.random_uniform(
          [common_layers.shape_list(y)[0]]), pd), y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      x_means_hot, x_means, q_loss, e_loss = kmeans(x, means, hparams)
      c = tf.argmax(x_means_hot, axis=-1)

      # Update the ema variables
      if hparams.ema:
        tf.logging.info("Using EMA with beta = {}".format(hparams.beta))
        x_means_hot_flat = tf.reshape(x_means_hot, shape=[-1, hparams.v_size])
        updated_ema_count = moving_averages.assign_moving_average(
            ema_count,
            tf.reduce_sum(x_means_hot_flat, axis=0),
            hparams.decay,
            zero_debias=False)
        x_flat = tf.reshape(x, [-1, hparams.hidden_size])
        dw = tf.matmul(x_means_hot_flat, x_flat, transpose_a=True)
        updated_ema_means = moving_averages.assign_moving_average(
            ema_means, dw, hparams.decay, zero_debias=False)
        n = tf.reduce_sum(updated_ema_count)
        updated_ema_count = ((updated_ema_count + hparams.epsilon) /
                             (n + hparams.v_size * hparams.epsilon) * n)
        updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

        with tf.control_dependencies([e_loss]):
          update_means = tf.assign(means, updated_ema_means)
          with tf.control_dependencies([update_means]):
            l = hparams.beta * e_loss
      else:
        l = q_loss + hparams.beta * e_loss

      h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)

    if hparams.bottleneck_kind == "rounding":
      h = tf.layers.dense(x, 1, name="vcc")

      # Make h between 0 and 1
      h = tf.sigmoid(h)

      # Multiply by v_size to get it between [0, v_size]
      h *= hparams.v_size

      # Use the rounding bottleneck
      h1 = h + tf.stop_gradient(tf.round(h) - h)
      c = tf.squeeze(tf.round(h), axis=-1)
      c = tf.to_int32(c)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
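A NumPy restatement (not the kmeans/embedding_lookup helpers themselves) of the two losses combined in the non-EMA vq-vae branch above: q_loss pulls the codebook toward the encoder outputs, while e_loss, scaled by beta, commits the encoder outputs to their assigned codes; in the TF code, stop_gradient decides which side each term actually trains.

import numpy as np

beta = 0.25
x = np.random.randn(32, 4)          # encoder outputs
means = np.random.randn(8, 4)       # codebook
nearest = np.argmin(((x[:, None] - means[None]) ** 2).sum(-1), axis=1)
x_means = means[nearest]

q_loss = np.mean((x_means - x) ** 2)   # in TF: mse(stop_gradient(x), x_means) -> trains the codebook
e_loss = np.mean((x - x_means) ** 2)   # in TF: mse(x, stop_gradient(x_means)) -> trains the encoder
l = q_loss + beta * e_loss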
Example #10
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""
  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=True):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        means = tf.get_variable(name="means",
                                shape=[hparams.v_size, hparams.hidden_size])
        h1 = tf.gather(means, x)
      elif hparams.bottleneck_kind == "rounding":
        h1 = x

      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(common_layers.shape_list(c),
                                    mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(tf.less(tf.random_uniform(
          [common_layers.shape_list(y)[0]]), pd), y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      means = tf.get_variable(name="means", shape=[hparams.v_size,
                                                   hparams.hidden_size])
      x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
      h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
      c = tf.argmax(x_means_hot, axis=-1)
    if hparams.bottleneck_kind == "rounding":
      h = tf.layers.dense(x, 1, name="vcc")

      # Make h between 0 and 1
      h = tf.sigmoid(h)

      # Multiply by v_size to get it between [0, v_size]
      h *= hparams.v_size

      # Use the rounding bottleneck
      h1 = h + tf.stop_gradient(tf.round(h) - h)
      c = tf.squeeze(tf.round(h), axis=-1)
      c = tf.to_int32(c)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
Example #11
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector
      quantization (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
  block_v_size = None
  if bottleneck_kind == "dvq":
    # Define the dvq parameters
    assert means is not None

    # Check block dimensions add up
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")

    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")

    z_size_per_residual = int(z_size / num_residuals)

    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")

    block_v_size = 2**(z_size_per_residual / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices
    if reshape_method == "slice":
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == "project":
      if projection_tensors is None:
        raise ValueError(
            "Projection tensors is None for reshape_method project")
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError("Unknown reshape_method")

    # Check if the ema settings make sense
    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == "gumbel-softmax":
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      x_reshaped = reshape_fn(x)
      x_res = x_reshaped
      x_means_hot = []
      x_means = 0
      l = 0
      for i in range(num_residuals):
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
            x_res, means[i], num_blocks, block_v_size, random_top_k, soft_em,
            num_samples)
        # Update the ema variables
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.reduce_sum(
                  tf.reshape(
                      x_means_hot_res, shape=[-1, num_blocks, block_v_size]),
                  axis=0),
              decay,
              zero_debias=False)

          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))

          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], dw, decay, zero_debias=False)
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = ((updated_ema_count_res + epsilon) /
                                   (n + 2**z_size * epsilon) * n)
          # pylint: disable=g-no-augmented-assignment
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)
          # pylint: enable=g-no-augmented-assignment

          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i], updated_ema_means_res)
            with tf.control_dependencies([update_means_res]):
              l += beta * e_loss_res
        else:
          l += q_loss_res + beta * e_loss_res

        # Update the residuals
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of c
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # If we are doing soft EM then c is x_means_hot
      if soft_em:
        c = x_means_hot
        new_shape.append(block_v_size)
        c = tf.reshape(c, new_shape)

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError("Unknown discretization method.")

    res = h1

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
Example #12
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind='dvq',
                        num_blocks=2,
                        reshape_method='slice',
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True,
                        dp_strength=1.0,
                        dp_decay=1.0,
                        dp_alpha=0.5):
    """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization.
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).
    dp_strength: Strength of Dirichlet Process loss prior (Default: 1.0).
    dp_decay: Decay the dp_strength using an exponential decay using this
      term (Default: 1.0).
    dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
    block_v_size = None
    if bottleneck_kind == 'dvq':
        # Define the dvq parameters
        assert means is not None

        # Check block dimensions add up
        if hidden_size % num_blocks != 0:
            raise ValueError('num_blocks does not divide hidden size')

        if 2**z_size % num_blocks != 0:
            raise ValueError('num_blocks does not divide embedding table size')

        block_v_size = 2**(z_size / num_blocks)
        block_v_size = int(block_v_size)

        # Set the reshape method corresponding to projections or slices
        if reshape_method == 'slice':
            reshape_fn = partial(slice_hidden,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        elif reshape_method == 'project':
            if projection_tensors is None:
                raise ValueError(
                    'Projection tensors is None for reshape_method project')
            reshape_fn = partial(project_hidden,
                                 projection_tensors=projection_tensors,
                                 hidden_size=hidden_size,
                                 num_blocks=num_blocks)
        else:
            raise ValueError('Unknown reshape_method')

        # Check if the ema settings make sense
        if ema:
            if ema_count is None:
                raise ValueError('ema_count is None but ema is True')
            if ema_means is None:
                raise ValueError('ema_means is None but ema is True')

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        l = tf.constant(0.0)
        if bottleneck_kind == 'dense':
            c = tf.layers.dense(x, z_size, name='vcc')
            h1 = tf.layers.dense(c, filter_size, name='vch1')
        elif bottleneck_kind == 'vae':
            c, l, _, _ = vae(x, z_size, 'vae')
            h1 = tf.layers.dense(c, filter_size, name='vch1')
        elif bottleneck_kind == 'semhash':
            c = tf.layers.dense(x, z_size, name='vcc')
            y_clean = common_layers.saturating_sigmoid(c)
            if summary:
                tf.summary.histogram('y_clean', tf.reshape(y_clean, [-1]))
            if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
                noise = tf.truncated_normal(common_layers.shape_list(c),
                                            mean=0.0,
                                            stddev=noise_dev)
                y = common_layers.saturating_sigmoid(c + noise)
            else:
                y = y_clean
            d = tf.to_float(tf.less(0.5, y))
            y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
            pd = common_layers.inverse_exp_decay(startup_steps * 2)
            pd *= discrete_mix
            pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
            c = tf.where(
                tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]),
                        pd), y_discrete, y)
            h1a = tf.layers.dense(c, filter_size, name='vch1a')
            h1b = tf.layers.dense(1.0 - c, filter_size, name='vch1b')
            h1 = h1a + h1b
            dx = tf.to_int32(tf.stop_gradient(d))
            c = bit_to_int(dx, z_size)
        elif bottleneck_kind == 'gumbel-softmax':
            _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                       kl_warmup_steps, summary)
            c = tf.argmax(hot, axis=-1)
            h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
        elif bottleneck_kind == 'dvq':
            x_reshaped = reshape_fn(x)
            x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
                x_reshaped, means, num_blocks, block_v_size, random_top_k)

            # Get the discrete latent representation
            x_means_idx = tf.argmax(x_means_hot, axis=-1)

            # Get the binary representation
            x_means_bits = int_to_bit(x_means_idx,
                                      num_bits=int(z_size / num_blocks),
                                      base=2)
            shape = common_layers.shape_list(x_means_bits)
            new_shape = shape[:-1]
            new_shape[-1] = z_size
            x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
            c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

            # Adjust shape of c
            shape_x = common_layers.shape_list(x)
            new_shape = shape_x[:-1]
            c = tf.reshape(c, new_shape)

            # Update the ema variables
            if ema:
                tf.logging.info('Using EMA with beta = {}'.format(beta))
                updated_ema_count = moving_averages.assign_moving_average(
                    ema_count,
                    tf.reduce_sum(tf.reshape(
                        x_means_hot, shape=[-1, num_blocks, block_v_size]),
                                  axis=0),
                    decay,
                    zero_debias=False)

                # Adding a term that puts a Dirichlet prior over cluster probabilities
                # Hopefully it'll encourage rich-get-richer behavior
                dp_prior_loss = 0.
                if dp_strength > 0.0:
                    # Decay dp_strength over time to make it less important
                    dp_strength = tf.train.exponential_decay(
                        dp_strength,
                        global_step=tf.to_int32(tf.train.get_global_step()),
                        decay_steps=20000,
                        decay_rate=dp_decay)
                    dp_count = ema_count + dp_alpha
                    p = dp_count / tf.reduce_sum(dp_count, 1, keepdims=True)
                    dp_prior_loss = tf.log(p)
                    dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
                    dp_prior_loss /= (num_blocks * block_v_size)

                x_means_hot_flat = tf.reshape(
                    x_means_hot, shape=[-1, num_blocks, block_v_size])
                dw = tf.matmul(tf.transpose(x_means_hot_flat, perm=[1, 2, 0]),
                               tf.transpose(x_reshaped, perm=[1, 0, 2]))
                updated_ema_means = moving_averages.assign_moving_average(
                    ema_means, dw, decay, zero_debias=False)
                n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True)
                updated_ema_count = ((updated_ema_count + epsilon) /
                                     (n + 2**z_size * epsilon) * n)
                updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

                with tf.control_dependencies([e_loss]):
                    update_means = tf.assign(means, updated_ema_means)
                    with tf.control_dependencies([update_means]):
                        l = beta * e_loss + dp_strength * dp_prior_loss
            else:
                l = q_loss + beta * e_loss

            x_means = tf.reshape(x_means, shape_x)
            x_reshaped = tf.reshape(x_reshaped, shape_x)
            h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
        else:
            raise ValueError('Unknown discretization method.')

        h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name='vch2')
        res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name='vcfin')

        embed_fn = partial(embed,
                           hidden_size=hidden_size,
                           z_size=z_size,
                           filter_size=filter_size,
                           name=name,
                           bottleneck_kind=bottleneck_kind,
                           num_blocks=num_blocks,
                           block_v_size=block_v_size,
                           means=means)
        return res, c, l, embed_fn
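A NumPy restatement of the Dirichlet-process prior term added to the loss in the EMA branch above, with made-up counts: the pseudo-count dp_alpha is added to the EMA cluster counts, and the negative sum of log cluster probabilities is normalized by the codebook size (the TF code additionally decays dp_strength over time).

import numpy as np

dp_alpha, num_blocks, block_v_size = 0.5, 1, 8
ema_count = np.array([[40., 10., 2., 0., 0., 0., 0., 0.]])   # [num_blocks, block_v_size], illustrative
dp_count = ema_count + dp_alpha
p = dp_count / dp_count.sum(axis=1, keepdims=True)
dp_prior_loss = -np.log(p).sum() / (num_blocks * block_v_size)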
Example #13
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind='dvq',
                        num_blocks=2,
                        reshape_method='slice',
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True,
                        dp_strength=1.0,
                        dp_decay=1.0,
                        dp_alpha=0.5):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce discrete code; discrete codes range
      from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization.
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.0).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for semhash
      (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages (Default:
      True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings (Default: None).
    summary: If True, then write summaries (Default: True).
    dp_strength: Strength of Dirichlet Process loss prior (Default: 1.0).
    dp_decay: Decay the dp_strength using an exponential decay using this
      term (Default: 1.0).
    dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the embedding
    function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, or
    ema_count or ema_means is None if we are using ema, or unknown args.
  """
  block_v_size = None
  if bottleneck_kind == 'dvq':
    # Define the dvq parameters
    assert means is not None

    # Check block dimensions add up
    if hidden_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide hidden size')

    if 2**z_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide embedding table size')

    block_v_size = 2**(z_size / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices
    if reshape_method == 'slice':
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == 'project':
      if projection_tensors is None:
        raise ValueError(
            'Projection tensors is None for reshape_method project')
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError('Unknown reshape_method')

    # Check if the ema settings make sense
    if ema:
      if ema_count is None:
        raise ValueError('ema_count is None but ema is True')
      if ema_means is None:
        raise ValueError('ema_means is None but ema is True')

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == 'dense':
      c = tf.layers.dense(x, z_size, name='vcc')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'vae':
      c, l, _, _ = vae(x, z_size, 'vae')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'semhash':
      c = tf.layers.dense(x, z_size, name='vcc')
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram('y_clean', tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name='vch1a')
      h1b = tf.layers.dense(1.0 - c, filter_size, name='vch1b')
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == 'gumbel-softmax':
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
    elif bottleneck_kind == 'dvq':
      x_reshaped = reshape_fn(x)
      x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
          x_reshaped, means, num_blocks, block_v_size, random_top_k)

      # Get the discrete latent representation
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation
      x_means_bits = int_to_bit(
          x_means_idx, num_bits=int(z_size / num_blocks), base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-1]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust shape of c
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # Update the ema variables
      if ema:
        tf.logging.info('Using EMA with beta = {}'.format(beta))
        updated_ema_count = moving_averages.assign_moving_average(
            ema_count,
            tf.reduce_sum(
                tf.reshape(x_means_hot, shape=[-1, num_blocks, block_v_size]),
                axis=0),
            decay,
            zero_debias=False)

        # Adding a term that puts a Dirichlet prior over cluster probabilities
        # Hopefully it'll encourage rich-get-richer behavior
        dp_prior_loss = 0.
        if dp_strength > 0.0:
          # Decay dp_strength over time to make it less important
          dp_strength = tf.train.exponential_decay(
              dp_strength,
              global_step=tf.to_int32(tf.train.get_global_step()),
              decay_steps=20000,
              decay_rate=dp_decay)
          dp_count = ema_count + dp_alpha
          p = dp_count / tf.reduce_sum(dp_count, 1, keepdims=True)
          dp_prior_loss = tf.log(p)
          dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
          dp_prior_loss /= (num_blocks * block_v_size)

        x_means_hot_flat = tf.reshape(
            x_means_hot, shape=[-1, num_blocks, block_v_size])
        dw = tf.matmul(
            tf.transpose(x_means_hot_flat, perm=[1, 2, 0]),
            tf.transpose(x_reshaped, perm=[1, 0, 2]))
        updated_ema_means = moving_averages.assign_moving_average(
            ema_means, dw, decay, zero_debias=False)
        n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True)
        updated_ema_count = ((updated_ema_count + epsilon) /
                             (n + 2**z_size * epsilon) * n)
        updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

        with tf.control_dependencies([e_loss]):
          update_means = tf.assign(means, updated_ema_means)
          with tf.control_dependencies([update_means]):
            l = beta * e_loss + dp_strength * dp_prior_loss
      else:
        l = q_loss + beta * e_loss

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError('Unknown discretization method.')

    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name='vch2')
    res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name='vcfin')

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        num_blocks=num_blocks,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
Example #14
def bottleneck(x, hparams, filter_size, name):
    """Bottleneck."""
    def embed(x):
        """Embedding function; must be compatible with the code later."""
        with tf.variable_scope(name, reuse=True):
            if hparams.bottleneck_kind == "semhash":
                c = int_to_bit(x, z_size)
                h1a = tf.layers.dense(c, filter_size, name="vch1a")
                h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
                h1 = h1a + h1b
            elif hparams.bottleneck_kind == "gumbel-softmax":
                hot = tf.one_hot(x, hparams.v_size)
                h1 = tf.layers.dense(hot,
                                     hparams.hidden_size,
                                     name="dae_dense")
            elif hparams.bottleneck_kind == "vq-vae":
                means = tf.get_variable(
                    name="means", shape=[hparams.v_size, hparams.hidden_size])
                h1 = tf.gather(means, x)

            h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
            return tf.layers.dense(tf.nn.relu(h2),
                                   hparams.hidden_size,
                                   name="vcfin")

    with tf.variable_scope(name):
        z_size = hparams.z_size
        l = tf.constant(0.0)
        if hparams.bottleneck_kind == "dense":
            c = tf.layers.dense(x, z_size, name="vcc")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        if hparams.bottleneck_kind == "vae":
            c, l, _, _ = vae(x, z_size, "vae")
            h1 = tf.layers.dense(c, filter_size, name="vch1")
        if hparams.bottleneck_kind == "semhash":
            c = tf.layers.dense(x, z_size, name="vcc")
            y_clean = common_layers.saturating_sigmoid(c)
            if _DO_SUMMARIES:
                tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
            if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
                dev = hparams.noise_dev
                noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
                y = common_layers.saturating_sigmoid(c + noise)
            else:
                y = y_clean
            d = tf.to_float(tf.less(0.5, y))
            y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
            pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
            pd *= hparams.d_mix
            pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            c = tf.cond(tf.less(tf.random_uniform([]), pd), lambda: y_discrete,
                        lambda: y)
            h1a = tf.layers.dense(c, filter_size, name="vch1a")
            h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
            h1 = h1a + h1b
            dx = tf.to_int32(tf.stop_gradient(d))
            c = bit_to_int(dx, z_size)
        if hparams.bottleneck_kind == "gumbel-softmax":
            _, hot, l = dae(x, hparams, name)
            c = tf.argmax(hot, axis=-1)
            h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
        if hparams.bottleneck_kind == "vq-vae":
            means = tf.get_variable(
                name="means", shape=[hparams.v_size, hparams.hidden_size])
            x_means_hot, x_means, l = kmeans(x,
                                             means,
                                             hparams,
                                             name="vq-vae-kmeans")
            h1 = x_means
            c = tf.argmax(x_means_hot, axis=-1)
        h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
        res = tf.layers.dense(tf.nn.relu(h2),
                              hparams.hidden_size,
                              name="vcfin")
        return res, c, l, embed