Exemplo n.º 1
0
def matmul_eval(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Computes an l0-regularized matmul in evaluation mode.

  The weights are scaled by the mean of the hard-concrete distribution
  described by `log_alpha`, so the result is deterministic.

  Args:
    x: 2D Tensor holding the input batch. Its static rank is checked.
    weight_parameters: 2-tuple of Tensors. The first entry holds the unscaled
      weight values, the second the log of the alpha values of the hard
      concrete distribution.
    transpose_a: If True, `x` is transposed before the multiplication.
    transpose_b: If True, the scaled weights are transposed before the
      multiplication.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to `common.GAMMA`.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to `common.ZETA`.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the gate is replaced by the mean of the learned
  # hard-concrete distribution, which makes the weights deterministic.
  gate_mean = common.hard_concrete_mean(log_alpha, gamma, zeta)
  return tf.matmul(
      x,
      theta * gate_mean,
      transpose_a=transpose_a,
      transpose_b=transpose_b)
Exemplo n.º 2
0
 def build(self, _):
     """Computes the effective weights of the cell and marks it as built.

     The weights are the element-wise product of a hard-concrete gate and
     `theta`. When `self._training` is set the gate is sampled, otherwise
     the mean of the distribution is used. The positional argument is
     ignored.
     """
     with ops.init_scope():
         theta, log_alpha = self._weight_parameters
         if not self._training:
             # Evaluation: use the mean of the distribution as the gate.
             gate = common.hard_concrete_mean(
                 log_alpha, self._gamma, self._zeta)
         else:
             # Training: draw the gate from the hard-concrete distribution.
             gate = common.hard_concrete_sample(
                 log_alpha, self._beta, self._gamma, self._zeta, self._eps)
         self._weights = gate * theta
     self.built = True
Exemplo n.º 3
0
def broadcast_matmul_eval(
    x,
    weight_parameters,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for l0 matrix multiplication with N input matrices.

  Contracts the last axis of `x` with the first axis of the 2D parameters.
  For a 3D input, each 2D matrix `x[i, :, :]` is multiplied independently
  with the parameters, giving a 3D output whose leading dimensions are
  `x.shape[:2]` and whose last dimension is `weight_parameters[0].shape[1]`.

  Args:
    x: Tensor representing the input batch. Must have a statically known
      rank of at least 2 (typically 3D).
    weight_parameters: 2-tuple of Tensors, where the first tensor is the
      unscaled weight values (rank 2) and the second is the log of the alpha
      values for the hard concrete distribution.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to `common.GAMMA`.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to `common.ZETA`.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
    ValueError: If the static rank of `x` is unknown or smaller than 2.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # The input must have rank 2 or greater. The check is an explicit raise
  # rather than an `assert` so it survives `python -O`, and so an unknown
  # static rank (ndims is None) gives a clear error instead of a TypeError.
  input_rank = x.get_shape().ndims
  if input_rank is None or input_rank < 2:
    raise ValueError(
        "Input `x` must have a statically known rank of at least 2, "
        "got rank {}.".format(input_rank))

  # Use the mean of the learned hard-concrete distribution as the
  # deterministic weight noise at evaluation time
  weight_noise = common.hard_concrete_mean(log_alpha, gamma, zeta)
  weights = theta * weight_noise

  # Compute the batch of matmuls by contracting the last axis of `x` with
  # the first axis of the weights
  return tf.tensordot(x, weights, [[input_rank - 1], [0]])
Exemplo n.º 4
0
def embedding_lookup_eval(
    weight_parameters,
    ids,
    name=None,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Computes an l0-regularized embedding lookup in evaluation mode.

  The rows selected by `ids` are gathered from both the weights and the
  log-alpha values, and the gathered weights are scaled by the mean of the
  hard-concrete distribution.

  Args:
    weight_parameters: 2-tuple of Tensors. The first entry holds the unscaled
      weight values, the second the log of the alpha values of the hard
      concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in the parameters.
    name: String. Name given to the output operator.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to `common.GAMMA`.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to `common.ZETA`.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather first, so the mean and the scaling below are only computed for
  # the looked-up rows rather than for the full parameter tensors.
  gathered_theta = layer_utils.gather(theta, ids)
  gathered_log_alpha = layer_utils.gather(log_alpha, ids)

  # Scale the gathered embedding vectors by the mean of the learned
  # hard-concrete distribution.
  gate_mean = common.hard_concrete_mean(gathered_log_alpha, gamma, zeta)
  scaled_embeddings = gathered_theta * gate_mean
  return tf.identity(scaled_embeddings, name=name)
Exemplo n.º 5
0
def conv2d_eval(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Computes an l0-regularized conv2d in evaluation mode.

  The filter weights are scaled by the mean of the hard-concrete
  distribution described by `log_alpha`, so the result is deterministic.

  Args:
    x: 4-D tf.Tensor representing the input batch of features, laid out
      according to `data_format`.
    weight_parameters: 2-tuple of Tensors. The first entry holds the unscaled
      weight values, the second the log of the alpha values of the hard
      concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Passed unchanged to tf.nn.conv2d.
    padding: String. One of "SAME", or "VALID". Passed unchanged to
      tf.nn.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of the 4-D input Tensor.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to `common.GAMMA`.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to `common.ZETA`.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the gate is replaced by the mean of the learned
  # hard-concrete distribution, which makes the filters deterministic.
  gate_mean = common.hard_concrete_mean(log_alpha, gamma, zeta)
  return tf.nn.conv2d(
      x,
      theta * gate_mean,
      strides=strides,
      padding=padding,
      data_format=data_format)