def matmul_eval(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Deterministic (evaluation-time) l0-regularized matmul.

  At evaluation time the stochastic hard-concrete gate is replaced by its
  expected value, so the effective weights are fixed.

  Args:
    x: 2D Tensor holding the input batch.
    weight_parameters: 2-tuple of Tensors: the unscaled weight values and
      the log-alpha parameters of the hard concrete distribution.
    transpose_a: If True, a is transposed before multiplication.
    transpose_b: If True, b is transposed before multiplication.
    gamma: Lower bound of the stretched distribution; defaults to -0.1
      from the above paper.
    zeta: Upper bound of the stretched distribution; defaults to 1.1
      from the above paper.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Scale the weights by the mean of the learned hard-concrete gate; this
  # is the deterministic counterpart of the training-time noise sample.
  mean_gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  effective_weights = theta * mean_gate
  return tf.matmul(
      x,
      effective_weights,
      transpose_a=transpose_a,
      transpose_b=transpose_b)
def build(self, _):
  """Initializes the weights for the RNN.

  In training mode the weights are gated by a sample drawn from the hard
  concrete distribution; otherwise the distribution's mean is used, making
  the weights deterministic for evaluation.
  """
  with ops.init_scope():
    theta, log_alpha = self._weight_parameters
    if self._training:
      gate = common.hard_concrete_sample(
          log_alpha, self._beta, self._gamma, self._zeta, self._eps)
    else:
      gate = common.hard_concrete_mean(
          log_alpha, self._gamma, self._zeta)
    self._weights = theta * gate
  self.built = True
def broadcast_matmul_eval(
    x,
    weight_parameters,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for l0 matrix multiplication with N input matrices.

  Multiplies a 3D tensor `x` with a set of 2D parameters. Each 2D matrix
  `x[i, :, :]` in the input tensor is multiplied independently with the
  parameters, resulting in a 3D output tensor with shape
  `x.shape[:2] + weight_parameters[0].shape[1]`.

  Args:
    x: 3D Tensor holding the input batch.
    weight_parameters: 2-tuple of Tensors: the unscaled weight values and
      the log-alpha parameters of the hard concrete distribution.
    gamma: Lower bound of the stretched distribution; defaults to -0.1
      from the above paper.
    zeta: Upper bound of the stretched distribution; defaults to 1.1
      from the above paper.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # `x` needs at least rank 2 so there is a trailing axis to contract
  # against the first axis of the weights.
  input_rank = x.get_shape().ndims
  assert input_rank >= 2

  # Deterministic evaluation-time gate: the hard-concrete mean.
  gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  gated_weights = theta * gate

  # tensordot over the last axis of `x` performs an independent matmul
  # for every leading batch index.
  return tf.tensordot(x, gated_weights, [[input_rank - 1], [0]])
def embedding_lookup_eval(
    weight_parameters,
    ids,
    name=None,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Deterministic (evaluation-time) l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple of Tensors: the unscaled weight values and
      the log-alpha parameters of the hard concrete distribution.
    ids: int32 or int64 Tensor of ids to look up in the embedding table.
    name: String. Name of the operator.
    gamma: Lower bound of the stretched distribution; defaults to -0.1
      from the above paper.
    zeta: Upper bound of the stretched distribution; defaults to 1.1
      from the above paper.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather first so the gate computation and scaling run on the smaller
  # looked-up batch instead of the full embedding table.
  picked_theta = layer_utils.gather(theta, ids)
  picked_log_alpha = layer_utils.gather(log_alpha, ids)

  # Scale the looked-up vectors by the hard-concrete mean (the
  # deterministic evaluation-time gate).
  gate = common.hard_concrete_mean(picked_log_alpha, gamma, zeta)
  return tf.identity(picked_theta * gate, name=name)
def conv2d_eval(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Deterministic (evaluation-time) l0-regularized conv2d.

  Args:
    x: NHWC tf.Tensor holding the input batch of features.
    weight_parameters: 2-tuple of Tensors: the unscaled weight values and
      the log-alpha parameters of the hard concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Identical to standard strides argument for tf.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to standard
      padding argument for tf.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    gamma: Lower bound of the stretched distribution; defaults to -0.1
      from the above paper.
    zeta: Upper bound of the stretched distribution; defaults to 1.1
      from the above paper.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Replace the stochastic gate with its hard-concrete mean so the
  # convolution filters are deterministic at evaluation time.
  gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  filters = theta * gate
  return tf.nn.conv2d(x, filters, strides, padding, data_format=data_format)