def test_tensordot(self):
   num_trials = min(30, num_dims_ * num_dims_)
   if dtype_ == np.float16:
     tol = 0.05
   elif dtype_ == np.float32 or dtype_ == np.complex64:
     tol = 1e-5
   else:
     tol = 1e-12
   for _ in range(num_trials):
     a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims()
     np_ans = np.tensordot(a_np, b_np, axes=(a_dims_np, b_dims_np))
     with self.test_session(use_gpu=True) as sess:
       if dynamic_shape_:
         a = array_ops.placeholder(dtype_)
         b = array_ops.placeholder(dtype_)
         axes = array_ops.placeholder(dtypes.int32)
         c = math_ops.tensordot(a, b, axes)
         tf_ans = sess.run(
             c, feed_dict={a: a_np,
                           b: b_np,
                           axes: (a_dims_np, b_dims_np)})
       else:
         tf_ans = math_ops.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
     self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
     self.assertAllEqual(tf_ans.shape, np_ans.shape)
 def test_tensordot_scalar_axes(self):
   if num_dims_ < 1:
     self.skipTest("Not a test")
   if dtype_ == np.float16:
     tol = 0.05
   elif dtype_ == np.float32 or dtype_ == np.complex64:
     tol = 1e-5
   else:
     tol = 1e-12
   shape = [5] * num_dims_
   a_np = np.random.uniform(
       low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
   b_np = np.random.uniform(
       low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
   all_axes = [1]
   if a_np.ndim > 1:
     all_axes.append(a_np.ndim - 1)
   for axes in all_axes:
     np_ans = np.tensordot(a_np, b_np, axes=axes)
     with self.test_session(use_gpu=True) as sess:
       if dynamic_shape_:
         a = array_ops.placeholder(dtype_)
         b = array_ops.placeholder(dtype_)
         c = math_ops.tensordot(a, b, axes=axes)
         tf_ans = sess.run(c, feed_dict={a: a_np, b: b_np})
       else:
         tf_ans = math_ops.tensordot(a_np, b_np, axes=axes).eval()
     self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
     self.assertAllEqual(tf_ans.shape, np_ans.shape)
  def test_invalid_axes(self):
    a = [[1, 2], [3, 4]]
    b = [[1, 2], [3, 4]]
    # Invalid static axes.
    for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]:
      with self.assertRaises(ValueError):
        math_ops.tensordot(a, b, axes_value)

    with self.assertRaises(IndexError):
      math_ops.tensordot(a, b, [[0], [7]])

    # Invalid dynamic axes.
    a_ph = array_ops.placeholder(dtypes.float32)
    b_ph = array_ops.placeholder(dtypes.float32)
    axes_ph = array_ops.placeholder(dtypes.int32)
    output = math_ops.tensordot(a_ph, b_ph, axes_ph)
    # Note: We don't support scalar Tensor values for axes.
    for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]:
      with self.cached_session() as sess:
        with self.assertRaises(errors_impl.InvalidArgumentError):
          _ = sess.run(
              [output], feed_dict={
                  a_ph: a,
                  b_ph: b,
                  axes_ph: axes_value
              })
Example #4
  def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
                    mat_gbar_weight, i):
    """Updates the cumulative outer products of the gradients.

    Args:
      mat_g: the matrix to be updated
      grad: the gradient of the variable
      axes: a list of k-1 integers 0 to k-1, except i
      mat_gbar_decay: constant for weighted average:
          mat_g = mat_g * decay + grad * weight
      mat_gbar_weight: constant for weighted average
      i: index of dimension to be updated.

    Returns:
      updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight

    In Einstein notation, if i = 0: grad_outer_{aa'} = g_{abcd} g_{a'bcd};
    thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
    i'th dimension of g.
    Alternate view: If mat_i(grad) is the flattening of grad to a
    d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then
         grad_outer = mat_i(grad) mat_i(grad).transpose
    """
    grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
                                    name="grad_outer_" + str(i))
    return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
                                  mat_gbar_weight * grad_outer)
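# A minimal NumPy sketch (not part of the optimizer) verifying the flattening
# view in the docstring above: tensordot over every axis except i equals
# mat_i(grad) @ mat_i(grad).T. Shapes and names here are illustrative.
import numpy as np

g = np.random.randn(2, 3, 4)      # hypothetical gradient with k = 3 dimensions
i = 1
axes = [0, 2]                     # every dimension except i

grad_outer = np.tensordot(g, g, axes=(axes, axes))      # d_i x d_i

mat_i = np.moveaxis(g, i, 0).reshape(g.shape[i], -1)    # d_i x (d_0 d_2)
assert np.allclose(grad_outer, mat_i @ mat_i.T)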
 def test_no_partial_shape_inference(self):
   # If one of the shapes is only partially defined, the output shape is
   # unknown.
   a = array_ops.placeholder(dtypes.float32)
   b = array_ops.placeholder(dtypes.float32)
   axes = ([1], [0])
   output = math_ops.tensordot(a, b, axes)
   self.assertEqual(output.get_shape().ndims, None)
   a.set_shape([None, 2])
   b.set_shape([2, 3])
   output = math_ops.tensordot(a, b, axes)
   self.assertEqual(output.get_shape().ndims, None)
   a = array_ops.placeholder(dtypes.float32)
   b = array_ops.placeholder(dtypes.float32)
   a.set_shape([2, 2])
   b.set_shape([2, None])
   output = math_ops.tensordot(a, b, axes)
   self.assertEqual(output.get_shape().ndims, None)
 def test_invalid_shape(self):
   a = [[1, 2], [3, 4]]
   b = [[1, 2], [3, 4], [5, 6]]
   a_axes = [1]
   b_axes = [0]
   # Invalid static shapes.
   with self.assertRaises(ValueError):
     math_ops.tensordot(a, b, (a_axes, b_axes))
   # Invalid dynamic shapes.
   with self.test_session() as sess:
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "Matrix size-incompatible"):
       a_ph = array_ops.placeholder(dtypes.float32)
       b_ph = array_ops.placeholder(dtypes.float32)
       axes_ph = array_ops.placeholder(dtypes.int32)
       output = math_ops.tensordot(a_ph, b_ph, axes_ph)
       _ = sess.run(
           [output], feed_dict={a_ph: a,
                                b_ph: b,
                                axes_ph: (a_axes, b_axes)})
  def test_valid_axis(self):
    for axes_value in [1, 2], [[1], [2]], [[], []], 0:
      with self.cached_session():
        np_a = np.ones((3, 3))
        np_b = np.array([2, 3, 1])[None, None]
        np_ans = np.tensordot(np_a, np_b, axes_value)

        tf_a = array_ops.ones((3, 3), dtype=dtypes.float32)
        tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
        tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value)

        self.assertAllEqual(tf_ans.shape, np_ans.shape)
        self.assertAllEqual(tf_ans, np_ans)
 def test_partial_shape_inference(self):
   a = array_ops.placeholder(dtypes.float32)
   b = array_ops.placeholder(dtypes.float32)
   axes = ([1], [0])
   output = math_ops.tensordot(a, b, axes)
   self.assertEqual(output.get_shape().ndims, None)
   a.set_shape([None, 2])
   b.set_shape([2, 3])
   output = math_ops.tensordot(a, b, axes)
   output_shape = output.get_shape()
   self.assertEqual(output_shape.ndims, 2)
   output_shape = output_shape.as_list()
   self.assertEqual(output_shape[0], None)
   self.assertEqual(output_shape[1], 3)
   a = array_ops.placeholder(dtypes.float32)
   b = array_ops.placeholder(dtypes.float32)
   a.set_shape([2, 2])
   b.set_shape([2, None])
   output = math_ops.tensordot(a, b, axes)
   output_shape = output.get_shape()
   self.assertEqual(output_shape.ndims, 2)
   output_shape = output_shape.as_list()
   self.assertEqual(output_shape[0], 2)
   self.assertEqual(output_shape[1], None)
Example #9
def embedding_matmul(embedding_table, values, mask, name='embedding_matmul'):
  """Performs embedding lookup via a matmul.

  The matrix to be multiplied by the embedding table Tensor is constructed
  via an implementation of scatter based on broadcasting embedding indices
  and performing an equality comparison against a broadcasted
  range(num_embedding_table_rows). All masked positions will produce an
  embedding vector of zeros.

  Args:
    embedding_table: Tensor of embedding table.
      Rank 2 (table_size x embedding dim)
    values: Tensor of embedding indices. Rank 2 (batch x n_indices)
    mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
    name: Optional name scope for created ops

  Returns:
    Rank 3 tensor of embedding vectors.
  """

  with ops.name_scope(name):
    n_embeddings, embedding_dim = embedding_table.get_shape().as_list()
    batch_size, padded_size = values.shape.as_list()

    emb_idcs = array_ops.tile(
        array_ops.reshape(values, (batch_size, padded_size, 1)), (1, 1,
                                                                  n_embeddings))
    emb_weights = array_ops.tile(
        array_ops.reshape(mask, (batch_size, padded_size, 1)),
        (1, 1, n_embeddings))
    col_idcs = array_ops.tile(
        array_ops.reshape(math_ops.range(n_embeddings), (1, 1, n_embeddings)),
        (batch_size, padded_size, 1))
    one_hot = array_ops.where(
        math_ops.equal(emb_idcs, col_idcs), emb_weights,
        array_ops.zeros((batch_size, padded_size, n_embeddings)))

    return math_ops.tensordot(one_hot, embedding_table, 1)
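# A NumPy sketch of the scatter-by-comparison idea used above (illustrative
# shapes only; embedding_matmul itself expects TensorFlow tensors with
# statically known shapes).
import numpy as np

table = np.random.randn(10, 4)                   # (table_size, embedding_dim)
values = np.array([[1, 3, 0], [2, 2, 9]])        # (batch, n_indices)
mask = np.array([[1., 1., 0.], [1., 0., 1.]])    # zero out padded positions

# Broadcast-compare against range(table_size) to build a masked one-hot tensor.
one_hot = (values[..., None] == np.arange(10)) * mask[..., None]

# Contracting the last axis against the table is the embedding lookup.
out = np.tensordot(one_hot, table, 1)            # (batch, n_indices, embedding_dim)
assert np.allclose(out[0, 0], table[1])          # unmasked position -> table row
assert np.allclose(out[0, 2], 0.0)               # masked position   -> zeros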
Example #10
def dot(lhs, rhs, name=None):
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)
Example #11
  def call(self, inputs):
    if K.dtype(inputs) != 'int32':
      inputs = math_ops.cast(inputs, 'int32')

    inputs = array_ops.one_hot(inputs, self.input_dim)
    return math_ops.tensordot(inputs, self.embeddings, 1)
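# This layer replaces a gather-based lookup with a one-hot contraction. A quick
# NumPy check of the equivalence (names and shapes here are illustrative):
import numpy as np

embeddings = np.random.randn(7, 3)     # (input_dim, output_dim)
ids = np.array([[0, 5], [2, 2]])       # int token ids, shape (batch, seq)

one_hot = np.eye(7)[ids]               # (batch, seq, input_dim)
# tensordot over the one-hot axis gives the same result as a plain gather.
assert np.allclose(np.tensordot(one_hot, embeddings, 1), embeddings[ids])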
Example #12
    def _apply_gradient(self, grad, var, indices=None):
        """The main function to update a variable.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
        global_step = self._global_step + 1

        # Update accumulated weighted average of gradients
        gbar = self.get_slot(var, "gbar")
        gbar_decay_t = GetParam(self._gbar_decay, global_step)
        gbar_weight_t = GetParam(self._gbar_weight, global_step)
        if indices is not None:
            # Note - the sparse update is not easily implemented, since the
            # algorithm needs all indices of gbar to be updated
            # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
            # One way to make mat_gbar_decay = 1 is by rescaling.
            # If we want the update:
            #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
            # define:
            #         r_{t+1} = a_{t+1} * r_t
            #         h_t = G_t / r_t
            # Then:
            #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
            # So we get the mat_gbar_decay = 1 as desired.
            # We can implement this in a future version as needed.
            # However we still need gbar_decay = 0, otherwise all indices
            # of the variable will need to be updated.
            if self._gbar_decay != 0.0:
                tf_logging.warning("Not applying momentum for variable: %s" %
                                   var.name)
            gbar_updated = grad
        else:
            gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                                  gbar_decay_t,
                                                  gbar_weight_t * grad)

        # Update the preconditioners and compute the preconditioned gradient
        shape = var.get_shape()
        mat_g_list = []
        for i in range(len(shape)):
            mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
        mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
        mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

        preconditioned_grad = gbar_updated
        v_rank = len(mat_g_list)
        neg_alpha = -GetParam(self._alpha, global_step) / v_rank
        svd_interval = GetParam(self._svd_interval, global_step)
        precond_update_interval = GetParam(self._precond_update_interval,
                                           global_step)
        for i, mat_g in enumerate(mat_g_list):
            # axes is the list of indices to reduce - everything but the current i.
            axes = list(range(i)) + list(range(i + 1, v_rank))
            if shape[i] <= self._max_matrix_size:
                # If the tensor size is sufficiently small perform full Shampoo update
                # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
                # is not strictly correct. However we will use it for now, and
                # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

                # pylint: disable=g-long-lambda,cell-var-from-loop
                mat_g_updated = control_flow_ops.cond(
                    math_ops.mod(global_step, precond_update_interval) < 1,
                    lambda: self._update_mat_g(
                        mat_g, grad, axes, mat_gbar_decay_t, mat_gbar_weight_t
                        * precond_update_interval, i), lambda: mat_g)

                mat_g_updated = mat_g_updated / float(shape[i].value)

                if self._svd_interval == 1:
                    mat_h = self._compute_power(var, mat_g_updated, shape[i],
                                                neg_alpha)
                else:
                    mat_h = control_flow_ops.cond(
                        math_ops.mod(global_step, svd_interval) < 1,
                        lambda: self._compute_power(var, mat_g_updated, shape[
                            i], neg_alpha, "H_" + str(i)),
                        lambda: self.get_slot(var, "H_" + str(i)))

                # mat_h is a square matrix of size d_i x d_i
                # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
                # After contraction with a d_i x d_i tensor
                # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
                # (the first dimension is contracted out, and the second dimension of
                # mat_h is appended).  After going through all the indices, it becomes
                # a d_0 x ... x d_n tensor again.
                preconditioned_grad = math_ops.tensordot(preconditioned_grad,
                                                         mat_h,
                                                         axes=([0], [0]),
                                                         name="precond_" +
                                                         str(i))
            else:
                # Tensor size is too large -- perform diagonal Shampoo update
                # Only normalize non-vector cases.
                if axes:
                    normalizer = 1.0 if indices is not None else float(
                        shape[i].value)
                    grad_outer = math_ops.reduce_sum(grad * grad,
                                                     axis=axes) / normalizer
                else:
                    grad_outer = grad * grad

                if i == 0 and indices is not None:
                    assert self._mat_gbar_decay == 1.0
                    mat_g_updated = state_ops.scatter_add(
                        mat_g, indices, mat_gbar_weight_t * grad_outer)
                    mat_h = math_ops.pow(
                        array_ops.gather(mat_g_updated, indices) +
                        self._epsilon, neg_alpha)
                else:
                    mat_g_updated = self._weighted_average(
                        mat_g, self._mat_gbar_decay, mat_gbar_decay_t,
                        mat_gbar_weight_t * grad_outer)
                    mat_h = math_ops.pow(mat_g_updated + self._epsilon,
                                         neg_alpha)

                # Need to do the transpose to ensure that the tensor becomes
                # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
                preconditioned_grad = array_ops.transpose(
                    preconditioned_grad,
                    perm=list(range(1, v_rank)) + [0]) * mat_h

        # Update the variable based on the Shampoo update
        learning_rate_t = GetParam(self._learning_rate, global_step)
        if indices is not None:
            var_updated = state_ops.scatter_add(
                var, indices, -learning_rate_t * preconditioned_grad)
        else:
            var_updated = state_ops.assign_sub(
                var, learning_rate_t * preconditioned_grad)
        return var_updated
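# The preconditioning loop above relies on tensordot(..., axes=([0], [0]))
# rotating the contracted axis to the back, so that after v_rank passes the
# gradient returns to its original d_0 x ... x d_n axis order. A minimal NumPy
# sketch of that bookkeeping, with identity matrices standing in for the
# computed preconditioners (illustrative only):
import numpy as np

grad = np.random.randn(2, 3, 4)
precond = grad
for d in grad.shape:
    mat_h = np.eye(d)  # stand-in for the preconditioner H_i
    # Contract axis 0 with axis 0 of mat_h; the remaining axis of mat_h is
    # appended at the end, which cycles the axes of the running tensor.
    precond = np.tensordot(precond, mat_h, axes=([0], [0]))

assert precond.shape == grad.shape      # axes are back in the original order
assert np.allclose(precond, grad)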
Example #13
    def call(self, inputs):
        if K.dtype(inputs) != 'int32':
            inputs = math_ops.cast(inputs, 'int32')

        inputs = array_ops.one_hot(inputs, self.input_dim)
        return math_ops.tensordot(inputs, self.embeddings, 1)
Example #14
 def _benchmark_tf_tensordot(self, device=CPU, execution_mode=None):
   with context.device(device):
     a = array_ops.ones((2, 2))
     b = array_ops.ones((2, 2))
     func = lambda: math_ops.tensordot(a, b, [[1], [0]])
     self._run(func, 30000, execution_mode=execution_mode)
Example #15
def dot(lhs, rhs, name=None):
    return math_ops.tensordot(lhs, rhs, axes=1, name=name)
Example #16
def linear(args,
           output_size,
           bias,
           handle=None,
           bias_initializer=None,
           kernel_initializer=None,
           kernel_name=_WEIGHTS_VARIABLE_NAME,
           bias_name=_BIAS_VARIABLE_NAME):
    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

    This function was originally copied from rnn_cell_impl.py and adds the
    capability to deal with 3-D tensors.

    Args:
        args: a 2D/3D Tensor or a list of 2D/3D, batch x n, Tensors.
        output_size: int, second dimension of W[i].
        bias: boolean, whether to add a bias term or not.
        handle: A Tensor. If provided, use it
          as the weight matrix.
        bias_initializer: starting value to initialize the bias
          (default is all zeros).
        kernel_initializer: starting value to initialize the weight.
        kernel_name: name of the weight variable.
        bias_name: name of the bias variable.

    Returns:
      A 2D/3D Tensor with shape [batch x output_size] equal to
      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

    Raises:
      ValueError: if any of the arguments has an unspecified or wrong shape.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):  # added by Chengqi
        ## capable of handling 3-D tensors
        shape = args.get_shape()
        if shape.ndims > 2:
            scope = vs.get_variable_scope()
            with vs.variable_scope(scope) as outer_scope:
                if handle is None:
                    weights = vs.get_variable(
                        kernel_name, [shape[-1].value, output_size],
                        dtype=args.dtype,
                        initializer=kernel_initializer)
                else:
                    assert output_size == handle.get_shape().as_list()[-1], \
                        "ouput_size should be the same as the last dimension of handle tensor"
                    weights = handle

                res = math_ops.tensordot(args, weights, [[shape.ndims - 1], [0]])

                if not bias:
                    return res
                with vs.variable_scope(outer_scope) as inner_scope:
                    inner_scope.set_partitioner(None)
                    if bias_initializer is None:
                        bias_initializer = init_ops.constant_initializer(0.0, dtype=args.dtype)
                    biases = vs.get_variable(
                        bias_name, [output_size],
                        dtype=args.dtype,
                        initializer=bias_initializer)
                return nn_ops.bias_add(res, biases)
        else:
            args = [args]

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
        if shape.ndims != 2:
            raise ValueError("linear is expecting 2D arguments: %s" % shapes)
        if shape[1].value is None:
            raise ValueError("linear expects shape[1] to be provided for shape %s, "
                             "but saw %s" % (shape, shape[1]))
        else:
            total_arg_size += shape[1].value

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
        if handle is None:
            weights = vs.get_variable(
                kernel_name, [total_arg_size, output_size],
                dtype=dtype,
                initializer=kernel_initializer)
        else:
            assert output_size == handle.get_shape().as_list()[-1], \
                "ouput_size should be the same as the last dimension of handle tensor"
            weights = handle
        if len(args) == 1:
            res = math_ops.matmul(args[0], weights)
        else:
            res = math_ops.matmul(array_ops.concat(args, 1), weights)
        if not bias:
            return res
        with vs.variable_scope(outer_scope) as inner_scope:
            inner_scope.set_partitioner(None)
            if bias_initializer is None:
                bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
            biases = vs.get_variable(
                bias_name, [output_size],
                dtype=dtype,
                initializer=bias_initializer)
        return nn_ops.bias_add(res, biases)
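# For the 3-D branch above, tensordot(args, weights, [[shape.ndims - 1], [0]])
# applies the same dense projection at every leading position. A NumPy check of
# the equivalence with an explicit reshape (shapes are illustrative):
import numpy as np

batch, time, in_dim, out_dim = 2, 5, 3, 4
x = np.random.randn(batch, time, in_dim)
w = np.random.randn(in_dim, out_dim)

td = np.tensordot(x, w, [[x.ndim - 1], [0]])                    # (batch, time, out_dim)
mm = (x.reshape(-1, in_dim) @ w).reshape(batch, time, out_dim)
assert np.allclose(td, mm)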
Example #17
 def f(a, b):
     return np_utils.cond(
         np_utils.logical_or(math_ops.equal(array_ops.rank(a), 0),
                             math_ops.equal(array_ops.rank(b), 0)),
         lambda: a * b, lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]))
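# The non-scalar branch contracts the last axis of each operand, which for
# NumPy arrays matches np.inner. A quick check (illustrative shapes):
import numpy as np

a = np.random.randn(2, 3, 4)
b = np.random.randn(5, 4)
assert np.allclose(np.tensordot(a, b, axes=[[-1], [-1]]), np.inner(a, b))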
Example #18
def tensordot(a, b, axes=2):
    return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b)
Example #19
  def _apply_gradient(self, grad, var, indices=None):
    """The main function to update a variable.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
    global_step = self._global_step + 1

    # Update accumulated weighted average of gradients
    gbar = self.get_slot(var, "gbar")
    gbar_decay_t = GetParam(self._gbar_decay, global_step)
    gbar_weight_t = GetParam(self._gbar_weight, global_step)
    if indices is not None:
      # Note - the sparse update is not easily implemented, since the
      # algorithm needs all indices of gbar to be updated
      # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
      # One way to make mat_gbar_decay = 1 is by rescaling.
      # If we want the update:
      #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
      # define:
      #         r_{t+1} = a_{t+1} * r_t
      #         h_t = G_t / r_t
      # Then:
      #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
      # So we get the mat_gbar_decay = 1 as desired.
      # We can implement this in a future version as needed.
      # However we still need gbar_decay = 0, otherwise all indices
      # of the variable will need to be updated.
      if self._gbar_decay != 0.0:
        tf_logging.warning("Not applying momentum for variable: %s" % var.name)
      gbar_updated = grad
    else:
      gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                            gbar_decay_t,
                                            gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient
    shape = var.get_shape()
    mat_g_list = []
    for i in range(len(shape)):
      mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    neg_alpha = - GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
      # axes is the list of indices to reduce - everything but the current i.
      axes = list(range(i)) + list(range(i+1, v_rank))
      if shape[i] <= self._max_matrix_size:
        # If the tensor size is sufficiently small perform full Shampoo update
        # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
        # is not strictly correct. However we will use it for now, and
        # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

        # pylint: disable=g-long-lambda,cell-var-from-loop
        mat_g_updated = control_flow_ops.cond(
            math_ops.mod(global_step, precond_update_interval) < 1,
            lambda: self._update_mat_g(
                mat_g, grad, axes, mat_gbar_decay_t,
                mat_gbar_weight_t * precond_update_interval, i),
            lambda: mat_g)

        if self._svd_interval == 1:
          mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
        else:
          mat_h = control_flow_ops.cond(
              math_ops.mod(global_step, svd_interval) < 1,
              lambda: self._compute_power(var, mat_g_updated, shape[i],
                                          neg_alpha, "H_" + str(i)),
              lambda: self.get_slot(var, "H_" + str(i)))

        # mat_h is a square matrix of size d_i x d_i
        # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
        # After contraction with a d_i x d_i tensor
        # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
        # (the first dimension is contracted out, and the second dimension of
        # mat_h is appended).  After going through all the indices, it becomes
        # a d_0 x ... x d_n tensor again.
        preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                 axes=([0], [0]),
                                                 name="precond_" + str(i))
      else:
        # Tensor size is too large -- perform diagonal Shampoo update
        grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
        if i == 0 and indices is not None:
          assert self._mat_gbar_decay == 1.0
          mat_g_updated = state_ops.scatter_add(mat_g, indices,
                                                mat_gbar_weight_t * grad_outer)
          mat_h = math_ops.pow(
              array_ops.gather(mat_g_updated, indices) + self._epsilon,
              neg_alpha)
        else:
          mat_g_updated = self._weighted_average(mat_g,
                                                 self._mat_gbar_decay,
                                                 mat_gbar_decay_t,
                                                 mat_gbar_weight_t * grad_outer)
          mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)

        # Need to do the transpose to ensure that the tensor becomes
        # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
        preconditioned_grad = array_ops.transpose(
            preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
      var_updated = state_ops.scatter_add(
          var, indices, -learning_rate_t * preconditioned_grad)
    else:
      var_updated = state_ops.assign_sub(var,
                                         learning_rate_t * preconditioned_grad)
    return var_updated
Example #20
    def time_aware_multihead_attention(
        self,
        queries,
        keys,
        key_length,
        query_length,
        t_querys,
        t_keys,
        t_querys_length,
        t_keys_length,
        num_units=None,
        num_heads=8,
        dropout_rate=0,
        is_training=True,
        scope="multihead_attention",
        reuse=None,
    ):
        '''Applies multihead attention.

        Args:
          queries: A 3d tensor with shape of [N, T_q, C_q].
          query_length: A 1d tensor with shape of [N].
          keys: A 3d tensor with shape of [N, T_k, C_k].
          key_length: A 1d tensor with shape of [N].
          t_querys: A 2d tensor of query timestamps with shape of [N, T_q].
          t_keys: A 2d tensor of key timestamps with shape of [N, T_k].
          t_querys_length: An int, T_q.
          t_keys_length: An int, T_k.
          num_units: A scalar. Attention size.
          dropout_rate: A floating point number.
          is_training: Boolean. Controller of mechanism for dropout.
          num_heads: An int. Number of heads.
          scope: Optional scope for `variable_scope`.
          reuse: Boolean, whether to reuse the weights of a previous layer
            by the same name.

        Returns:
          A 3d tensor with shape of (N, T_q, C)
        '''
        # Linear projections, C = # dim or column, T_x = # vectors or actions
        Q = tf.layers.dense(queries, num_units,
                            activation=tf.nn.relu)  # (N, T_q, C)
        #Q = tf.layers.dropout(Q, rate=dropout_rate, training=tf.convert_to_tensor(is_training))
        K = tf.layers.dense(keys, num_units,
                            activation=tf.nn.relu)  # (N, T_k, C)
        #K = tf.layers.dropout(K, rate=dropout_rate, training=tf.convert_to_tensor(is_training))
        V = tf.layers.dense(keys, num_units,
                            activation=tf.nn.relu)  # (N, T_k, C)
        #V = tf.layers.dropout(V, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        with tf.variable_scope(scope, reuse=reuse):
            # Set the fall back option for num_units
            if num_units is None:
                num_units = queries.get_shape().as_list()[-1]
            #list = t_querys.get_shape().as_list()
            #query_len = queries.get_shape().as_list()[-2]
            #key_len = queries.get_shape().as_list()[-2]

            # time decay gate
            scope = variable_scope.get_variable_scope()
            with variable_scope.variable_scope(scope,
                                               reuse=None) as unit_scope:
                with variable_scope.variable_scope(unit_scope):
                    time_input_w = variable_scope.get_variable(
                        "_time_input_w",
                        shape=[num_units, num_units],
                        dtype=queries.dtype)
                    '''
                    time_input_b = variable_scope.get_variable("_time_input_b",
                                                                shape=[t_querys_length, t_keys_length],
                                                                dtype=queries.dtype)
                    time_input_w1 = variable_scope.get_variable("_time_input_w1",
                                                               shape=[t_querys_length, t_keys_length],
                                                               dtype=queries.dtype)
                    time_input_b1 = variable_scope.get_variable("_time_input_b1",
                                                                shape=[t_querys_length, t_keys_length],
                                                                dtype=queries.dtype)
                    time_output_w1 = variable_scope.get_variable("time_output_w1",
                                                               shape=[t_querys_length, t_keys_length],
                                                               dtype=queries.dtype)
                    time_output_w2 = variable_scope.get_variable("time_output_w2",
                                                                 shape=[t_querys_length, t_keys_length],
                                                                 dtype=queries.dtype)
                    time_output_b = variable_scope.get_variable("time_output_b",
                                                               shape=[t_querys_length, t_keys_length],
                                                               dtype=queries.dtype)
                    '''
                    #time_input_b = variable_scope.get_variable("_time_input_b",
                    #shape=[t_querys_length, t_keys_length],
                    #dtype=queries.dtype)
                    time_input_w1 = variable_scope.get_variable(
                        "_time_input_w1",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    time_input_b1 = variable_scope.get_variable(
                        "_time_input_b1",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    time_output_w1 = variable_scope.get_variable(
                        "time_output_w1",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    time_output_w2 = variable_scope.get_variable(
                        "time_output_w2",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    time_output_w3 = variable_scope.get_variable(
                        "time_output_w3",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    time_output_b = variable_scope.get_variable(
                        "time_output_b",
                        shape=[t_querys_length, t_keys_length],
                        dtype=queries.dtype)
                    #time_w = variable_scope.get_variable(
                    #"_time_w", shape=[query_len, key_len], dtype=queries.dtype)
                    #time_b = variable_scope.get_variable(
                    #"_time_b", shape=[query_len, key_len], dtype=queries.dtype)
                    #time_b2 = variable_scope.get_variable(
                    # "_time_b2", shape=[query_len, key_len], dtype=queries.dtype)

            #time_query_key = tf.matmul(queries,time_input_w, name ='1')
            time_query_key = math_ops.tensordot(Q, time_input_w, [[2], [0]])
            time_query_key = tf.matmul(time_query_key,
                                       keys,
                                       transpose_b=True,
                                       name='2')
            #time_query_key = tf.nn.tanh(time_query_key+time_input_b)
            time_query_key = tf.nn.tanh(time_query_key)
            #time_query_key = tf.layers.dropout(time_query_key, rate=dropout_rate, training=tf.convert_to_tensor(is_training))
            '''
            t_querys = tf.expand_dims(t_querys,2 )
            t_querys = tf.concat([t_querys] * t_keys_length, axis=2)
            '''
            t_querys = tf.stack([t_querys] * t_keys_length, axis=2)
            '''
            t_keys = tf.expand_dims(t_keys, 1)
            t_keys = tf.concat([t_keys] * t_querys_length, axis=1)
            '''
            t_keys = tf.stack([t_keys] * t_querys_length, axis=1)

            #decay = tf.relu(time_w * tf.log((t_querys - tf.transpose(t_keys))+1)+time_b)
            decay = tf.log(tf.add(tf.abs(tf.subtract(t_querys, t_keys)), 1))
            #decay_mean = tf.reduce_sum(decay)/(t_keys_length*t_querys_length)
            #decay = decay/(decay_mean+1)
            #decay = self.normalize(decay)
            decay = tf.nn.tanh(decay * time_input_w1 + time_input_b1)
            #decay = tf.nn.tanh(decay * time_input_w1)

            #decay_gate = time_output_w1 * decay * time_query_key + time_output_b 1
            #decay_gate = time_output_w1 * decay + time_output_b 1
            # 3
            decay_gate = time_output_w1 * decay + time_output_w2 * time_query_key + time_output_b

            #decay_gate = tf.sigmoid(time_output_w1*decay*time_query_key+time_output_b)
            #decay_gate = tf.exp(-time_query_key * decay)
            #sigmoid -> exp decay 0.145 0.067
            #relu sigmoid 0.150 0.729
            #relu ->exp decay 0.1423 0.0676
            #relu-> sigmoid + 0.156
            #relu-> sigmoid + split
            #relu sigmoid time_output_w1*decay+time_output_w2*time_query_key+time_output_b
            #0.50 0.68

            # Split and concat
            Q_ = tf.concat(tf.split(Q, num_heads, axis=2),
                           axis=0)  # (h*N, T_q, C/h)
            K_ = tf.concat(tf.split(K, num_heads, axis=2),
                           axis=0)  # (h*N, T_k, C/h)
            V_ = tf.concat(tf.split(V, num_heads, axis=2),
                           axis=0)  # (h*N, T_k, C/h)
            decay_gate_ = tf.concat([decay_gate] * num_heads,
                                    axis=0)  # (h*N, T_k, C/h)
            #decay_gate_ = tf.layers.dropout(decay_gate_, rate=dropout_rate,
            #training=tf.convert_to_tensor(is_training))

            # Multiplication
            # query-key score matrix
            # each big score matrix is then split into h score matrix with same size
            # w.r.t. different part of the feature
            outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]),
                                name='3')  # (h*N, T_q, T_k)
            outputs *= tf.nn.sigmoid(decay_gate_)

            # Scale
            outputs = outputs / (K_.get_shape().as_list()[-1]**0.5)

            # Key Masking
            #key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))# (N, T_k)
            key_masks = tf.sequence_mask(key_length,
                                         tf.shape(keys)[1])  # (N, T_k)
            key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
            key_masks = tf.tile(tf.expand_dims(
                key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)

            paddings = tf.ones_like(outputs) * (-2**32 + 1)
            #outputs = tf.where(tf.equal(key_masks, 0), outputs, paddings)  # (h*N, T_q, T_k)
            outputs = tf.where(key_masks, outputs, paddings)  # (h*N, T_q, T_k)

            # Causality = Future blinding: No use, removed

            # Activation
            outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)
            '''
            # Query Masking
            query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
            query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
            query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
            outputs *= query_masks  # broadcasting. (N, T_q, C)

            # Attention vector
            att_vec = outputs

            # Dropouts
            outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

            # Weighted sum
            outputs = tf.matmul(outputs, V_)  # ( h*N, T_q, C/h)

            # Restore shape
            outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

            # Residual connection
            outputs += queries

            # Normalize
            outputs = self.normalize(outputs)  # (N, T_q, C)
            '''

            # Query Masking
            #query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
            query_masks = tf.sequence_mask(query_length,
                                           tf.shape(queries)[1],
                                           dtype=tf.float32)  # (N, T_q)
            query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
            query_masks = tf.tile(tf.expand_dims(query_masks, -1),
                                  [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
            outputs *= query_masks  # broadcasting. (N, T_q, C)
            print(outputs.shape.as_list())
            print(query_masks.shape.as_list())

            # Attention vector
            #########Tom Sun
            att_vec = outputs

            # Dropouts
            #outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

            # Weighted sum
            outputs = tf.matmul(outputs, V_, name='4')  # ( h*N, T_q, C/h)

            # Restore shape
            outputs = tf.concat(tf.split(outputs, num_heads, axis=0),
                                axis=2)  # (N, T_q, C)

            # Residual connection
            outputs += queries

            # Normalize
            outputs = self.normalize(outputs)  # (N, T_q, C)

        return outputs, att_vec
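# The tensordot(Q, time_input_w, [[2], [0]]) call above applies the same
# num_units x num_units projection at every (batch, time) position of Q.
# A minimal NumPy sketch of that contraction (hypothetical shapes):
import numpy as np

N, T_q, C = 2, 5, 8                    # batch, query length, attention size
Q = np.random.randn(N, T_q, C)
time_input_w = np.random.randn(C, C)

proj = np.tensordot(Q, time_input_w, [[2], [0]])    # (N, T_q, C)
ref = np.einsum('ntc,cd->ntd', Q, time_input_w)
assert np.allclose(proj, ref)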