Example #1
# Assumed imports for this snippet (TF 1.x; exact module paths may differ by version):
import tensorflow as tf
from tensorflow.contrib import eager as contrib_eager


def reshape_like(a, b):
    """Reshapes a to match the shape of b in all but the last dimension."""
    # Dynamic reshape: take every dimension of b except the last, keep a's last.
    ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
    if not contrib_eager.in_eager_mode():
        # In graph mode, also propagate the static shape for downstream ops.
        ret.set_shape(b.get_shape().as_list()[:-1] +
                      a.get_shape().as_list()[-1:])
    return ret
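
A minimal usage sketch (the tensors and shapes below are illustrative, not from the source): a carries the features to keep in the last dimension, b supplies the leading dimensions to restore.

a = tf.zeros([6, 4])      # e.g. a flattened [batch * length, depth] tensor
b = tf.zeros([2, 3, 5])   # e.g. [batch, length, hidden]
out = reshape_like(a, b)  # result has shape [2, 3, 4]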
Example #2
# Assumed imports for this snippet (tensor2tensor + TF 1.x; paths may differ by version):
import tensorflow as tf
from tensorflow.contrib import eager as contrib_eager
from tensor2tensor.layers import common_layers


def _get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Copied from tensor2tensor/layers/modalities.py but uses total vocab.

  Here `vocab_size` is a sequence of per-feature vocabulary sizes; the
  embedding covers sum(vocab_size) rows in total.
  """
  if hidden_dim is None:
    hidden_dim = model_hparams.hidden_size
  num_shards = model_hparams.symbol_modality_num_shards
  shards = []
  for i in range(num_shards):
    # Split the total vocabulary as evenly as possible across shards; the
    # first (total % num_shards) shards get one extra row.
    shard_size = (sum(vocab_size) // num_shards) + (
        1 if i < sum(vocab_size) % num_shards else 0)
    var_name = 'weights_%d' % i
    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, hidden_dim**-0.5)))
  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)
  # Convert ret to tensor (forces a dense gradient instead of IndexedSlices).
  if not contrib_eager.in_eager_mode():
    ret = common_layers.convert_gradient_to_tensor(ret)
  return ret
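
A minimal usage sketch for the variant above, assuming TF 1.x graph mode and tf.contrib.training.HParams; the vocabulary sizes and hidden size are illustrative, not from the source.

hparams = tf.contrib.training.HParams(
    hidden_size=512, symbol_modality_num_shards=2)
embedding = _get_weights(hparams, vocab_size=[8000, 8000])  # shape [16000, 512]
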
# This sparse-aware variant additionally assumes `common_sparse` (collection
# name constants), `common_layers`, and `contrib_eager` from the surrounding
# project, plus a module-level flag (assumed initialization shown here):
COLLECTED_VARIABLES = False


def _get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Create or get concatenated embedding or softmax variable.

  Args:
    model_hparams: tf.HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    hidden_dim: dim of the variable. Defaults to model_hparams.hidden_size.

  Returns:
    a Tensor of shape [vocab_size, hidden_dim] built from num_shards shards,
    or a (weights, aux_params) tuple of such Tensors when a sparsity
    technique is in use.
  """
  if hidden_dim is None:
    hidden_dim = model_hparams.hidden_size
  num_shards = model_hparams.symbol_modality_num_shards
  shards = []

  sparsity_technique = model_hparams.get("sparsity_technique")
  aux_params_shards = []
  for i in range(num_shards):
    shard_size = (vocab_size // num_shards) + (
        1 if i < vocab_size % num_shards else 0)
    var_name = "weights_%d" % i

    weight_init_stddev = hidden_dim**-0.5
    if (model_hparams.get("load_masks_from") and
        model_hparams.get("initial_sparsity")):
      # If we are loading constant masks for scratch-e or scratch-b
      # experiments, we optionally rescale the variance of the weight
      # initialization.
      initial_sparsity = model_hparams.get("initial_sparsity")
      weight_init_stddev = (hidden_dim * (1 - initial_sparsity))**-0.5
      tf.logging.info("Using sparse initialization with sparsity {} for symbol "
                      .format(initial_sparsity))

    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, weight_init_stddev)))
    if sparsity_technique == "variational_dropout":
      # Per-weight auxiliary parameters (log-variances), initialized to a
      # large negative value so the learned noise starts out negligible.
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=tf.constant_initializer(value=-10.0)))
    elif sparsity_technique == "l0_regularization":
      # Per-weight gate parameters; 2.197 ~= log(9), so the stochastic gates
      # start out almost fully open.
      initializer = tf.random_normal_initializer(mean=2.197, stddev=0.01)
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=initializer))

  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)

  if not aux_params_shards:
    # Convert ret to tensor.
    if not contrib_eager.in_eager_mode():
      ret = common_layers.convert_gradient_to_tensor(ret)
    return ret

  # Handle the auxiliary parameters
  if num_shards == 1:
    aux_ret = aux_params_shards[0]
  else:
    aux_ret = tf.concat(aux_params_shards, 0)

  # Register the (weights, aux) pair with its collection only once, even if
  # this function is called more than once.
  global COLLECTED_VARIABLES
  if not COLLECTED_VARIABLES:
    if sparsity_technique == "variational_dropout":
      tf.add_to_collection(
          common_sparse.VARIATIONAL_DROPOUT_PARAMETERS,
          (ret, aux_ret))
    elif sparsity_technique == "l0_regularization":
      tf.add_to_collection(
          common_sparse.L0_REGULARIZATION_PARAMETERS,
          (ret, aux_ret))
    COLLECTED_VARIABLES = True

  # Convert aux ret to tensor.
  if not contrib_eager.in_eager_mode():
    ret = common_layers.convert_gradient_to_tensor(ret)
    aux_ret = common_layers.convert_gradient_to_tensor(aux_ret)
  return (ret, aux_ret)
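
A minimal usage sketch for the sparse-aware variant, assuming TF 1.x graph mode and tf.contrib.training.HParams; the hyperparameter values and vocabulary size are illustrative, not from the source.

hparams = tf.contrib.training.HParams(
    hidden_size=512,
    symbol_modality_num_shards=1,
    sparsity_technique="variational_dropout")
weights, aux_params = _get_weights(hparams, vocab_size=32000)
# weights and aux_params both have shape [32000, 512]; the pair is also
# registered in the variational-dropout collection for the regularizer.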