Example #1
 def bottom_simple(self, x, name, reuse):
   with tf.variable_scope(name, reuse=reuse):
     # Squeeze out the channels dimension.
     x = tf.squeeze(x, axis=3)
     var = self._get_weights()
     x = common_layers.dropout_no_scaling(
         x, 1.0 - self._model_hparams.symbol_dropout)
     ret = common_layers.gather(var, x)
     if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
       ret *= self._body_input_depth**0.5
     ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
     return ret
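
The trailing multiply zeroes the embedding wherever the symbol id is 0, the id this code treats as padding. A minimal NumPy sketch of the gather-and-mask pattern (table size and ids invented for illustration):

import numpy as np

vocab_size, depth = 6, 4
var = np.random.randn(vocab_size, depth)  # embedding table
x = np.array([[3, 0, 5]])                 # symbol ids; 0 is padding

ret = var[x]                              # gather -> shape (1, 3, 4)
ret *= (x != 0)[..., np.newaxis]          # padding rows become all-zero

With multiply_embedding_mode == "sqrt_depth", the gathered embeddings would additionally be scaled by the square root of the embedding depth before masking, as in the snippet above.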
Example #2
 def bottom_simple(self, x, name, reuse):
     with tf.variable_scope(name, reuse=reuse):
         var = self._get_weights()
         x = common_layers.dropout_no_scaling(
             x, 1.0 - self._model_hparams.symbol_dropout)
         # Add together the embeddings for each tuple position.
         ret = tf.add_n([
             tf.gather(var, x[:, :, :, i] + sum(self._vocab_size[:i])) *
             tf.expand_dims(tf.to_float(tf.not_equal(x[:, :, :, i], 0)), -1)
             for i in range(len(self._vocab_size))
         ])
         if self._model_hparams.multiply_embedding_mode == 'sqrt_depth':
             ret *= self._body_input_depth**0.5
         return ret
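
Unlike Example #1, here each symbol is a tuple of ids drawn from several vocabularies that share one embedding table: the offset sum(self._vocab_size[:i]) shifts position i's ids past the rows used by earlier positions, and the per-position mask again zeroes padding ids. A small NumPy illustration with invented vocabulary sizes:

import numpy as np

vocab_size = [5, 3]                        # hypothetical sub-vocabularies
var = np.random.randn(sum(vocab_size), 4)  # one shared table with 8 rows

x = np.array([[[[2, 1]]]])                 # [batch, length, 1, positions]
ret = sum(
    var[x[:, :, :, i] + sum(vocab_size[:i])] *
    (x[:, :, :, i] != 0)[..., np.newaxis]
    for i in range(len(vocab_size)))
# Position 1's id 1 reads row 5 + 1 = 6, past the 5 rows of position 0.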
Example #3
 def bottom_simple(self, x, name, reuse):
   with tf.variable_scope(name, reuse=reuse):
     var = self._get_weights()
     x = common_layers.dropout_no_scaling(
         x, 1.0 - self._model_hparams.symbol_dropout)
     # Add together the embeddings for each tuple position.
     ret = tf.add_n([
         tf.gather(var, x[:, :, :, i] + sum(self._vocab_size[:i])) *
         tf.expand_dims(tf.to_float(tf.not_equal(x[:, :, :, i], 0)), -1)
         for i in range(len(self._vocab_size))
     ])
     if self._model_hparams.multiply_embedding_mode == 'sqrt_depth':
       ret *= self._body_input_depth**0.5
     return ret
Example #4
def bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Internal bottom transformation."""
  with tf.variable_scope(name, reuse=reuse):
    var = _get_weights(model_hparams, vocab_size)
    x = common_layers.dropout_no_scaling(
        x, 1.0 - model_hparams.symbol_dropout)
    # Add together the embeddings for each tuple position.
    ret = tf.add_n([
        tf.gather(var, x[:, :, :, i] + sum(vocab_size[:i])) *
        tf.expand_dims(tf.to_float(tf.not_equal(x[:, :, :, i], 0)), -1)
        for i in range(len(vocab_size))
    ])
    if model_hparams.multiply_embedding_mode == 'sqrt_depth':
      ret *= model_hparams.hidden_size**0.5
    return ret
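
In all of these variants, symbol_dropout is applied to the ids themselves rather than to the embeddings: dropout_no_scaling zeroes symbols without the usual 1/keep_prob rescaling, so a dropped symbol becomes id 0 and is subsequently masked out like padding. A rough sketch of my reading of that helper (not its actual implementation):

import numpy as np

def symbol_dropout_sketch(ids, dropout_rate, rng=np.random):
    # Keep each symbol with probability 1 - dropout_rate; dropped
    # symbols collapse to id 0 and pick up the zero embedding above.
    keep = rng.random_sample(ids.shape) >= dropout_rate
    return ids * keep

print(symbol_dropout_sketch(np.array([[3, 7, 5, 2]]), dropout_rate=0.25))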
Example #5
 def bottom_simple(self, x, name, reuse):
     with tf.variable_scope(name, reuse=reuse):
         # Ensure the inputs are 3-D
         if len(x.get_shape()) == 4:
             x = tf.squeeze(x, axis=3)
         while len(x.get_shape()) < 3:
             x = tf.expand_dims(x, axis=-1)
         var = self._get_weights()
         x = common_layers.dropout_no_scaling(
             x, 1.0 - self._model_hparams.symbol_dropout)
         ret = common_layers.gather(var, x)
         if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
             ret *= self._body_input_depth**0.5
         ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
         return ret
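
This variant first normalizes the ids to a rank-3 [batch, length, 1] tensor before the lookup. The same shape logic in plain NumPy:

import numpy as np

def ensure_3d(ids):
    # Drop a trailing channels axis, then pad trailing axes
    # until ids are rank 3, i.e. [batch, length, 1].
    if ids.ndim == 4:
        ids = np.squeeze(ids, axis=3)
    while ids.ndim < 3:
        ids = ids[..., np.newaxis]
    return ids

print(ensure_3d(np.zeros((2, 5))).shape)        # (2, 5, 1)
print(ensure_3d(np.zeros((2, 5, 1, 1))).shape)  # (2, 5, 1)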
Example #6
  def bottom_simple(self, x, name, reuse):
    with tf.variable_scope(name, reuse=reuse):
      # Ensure the inputs are 3-D
      if len(x.get_shape()) == 4:
        x = tf.squeeze(x, axis=3)
      while len(x.get_shape()) < 3:
        x = tf.expand_dims(x, axis=-1)

      var = self._get_weights()
      x = common_layers.dropout_no_scaling(
          x, 1.0 - self._model_hparams.symbol_dropout)
      ret = common_layers.gather(var, x)
      if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
        ret *= self._body_input_depth**0.5
      ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
      return ret
Example #7
def bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Bottom transformation."""
  with tf.variable_scope(name, reuse=reuse):
    # Ensure the inputs are 3-D
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    var = _get_weights(model_hparams, vocab_size)
    x = common_layers.dropout_no_scaling(
        x, 1.0 - model_hparams.symbol_dropout)

    sparsity_technique = model_hparams.get("sparsity_technique")
    training = model_hparams.get("mode") == tf.estimator.ModeKeys.TRAIN
    if sparsity_technique == "variational_dropout":
      if training:
        ret = vd.nn.embedding_lookup_train(
            var,
            x,
            clip_alpha=model_hparams.get("clip_log_alpha"))
      else:
        threshold = model_hparams.get("log_alpha_threshold")
        ret = vd.nn.embedding_lookup_eval(
            var,
            x,
            threshold=threshold)
    elif sparsity_technique == "l0_regularization":
      if training:
        ret = l0.nn.embedding_lookup_train(var, x)
      else:
        ret = l0.nn.embedding_lookup_eval(var, x)
    elif (sparsity_technique == "magnitude_pruning" or
          sparsity_technique == "random_pruning"):
      ret = common_layers.gather(pruning.apply_mask(var), x)
    else:
      ret = common_layers.gather(var, x)

    # post-process the embedding vectors
    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= model_hparams.hidden_size**0.5
    ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
    return ret
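
This last variant adds sparsity-aware lookups. In the two pruning branches, apply_mask (from TensorFlow's model-pruning contrib library, as far as I can tell) multiplies the weights by a binary mask variable, so pruned embedding rows read back as zeros. A loose NumPy analogy in which the mask is computed ad hoc rather than maintained by a pruning schedule:

import numpy as np

var = np.random.randn(8, 4)                 # hypothetical embedding table
magnitudes = np.abs(var)
mask = magnitudes > np.median(magnitudes)   # drop the smaller half of weights
ret = (var * mask)[np.array([[3, 0, 5]])]   # gather from the masked table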