def testSddmm(self, m, k, n, sparsity, use_gpu):
        """Compares ops.sddmm (rhs transposed) with a masked numpy product."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        # Reference operands; the output pattern is a random 0/1 mask.
        a = sample([m, k])
        b = sample([n, k])
        pattern = make_mask(np.ones([m, n]))

        # Graph under test.
        topology = sparse_matrix.SparseMatrix("output", matrix=pattern)
        a_var = tf.Variable(a, dtype=tf.float32)
        b_var = tf.Variable(b, dtype=tf.float32)
        result = ops.sddmm(a_var, b_var, topology, transpose_rhs=True)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            reference = self.dense_to_scipy(pattern * np.dot(a, b.T))
            fetched = sess.run(
                [result.values, result.row_offsets, result.column_indices])
            computed = self.sparse_to_scipy(*fetched, shape=reference.shape)
            self.assert_sparse_matrix_equal(computed,
                                            reference,
                                            atol=1e-03,
                                            rtol=1e-05)
    def testSpmm_Replicated(self, r, m, k, n, sparsity, use_gpu):
        """Checks replicated_spmm: r sparse lhs operands sharing one topology."""
        make_mask = connectors.Uniform(sparsity, round_to=4)
        sample = initializers.Uniform()

        # Binary mask shared by every replica; each replica has its own values.
        mask = make_mask(sample([m, k]))
        mask[mask != 0] = 1.0
        lhs_dense = mask[np.newaxis, :, :] * sample([r, m, k])
        rhs_dense = sample([r, k, n])

        topology = sparse_matrix.SparseTopology("topology", mask=mask)
        # Nonzeros gathered in row-major order, one row of values per replica.
        packed = lhs_dense[lhs_dense != 0].reshape(r, -1)
        lhs = tf.Variable(packed, dtype=tf.float32)
        rhs = tf.Variable(rhs_dense, dtype=tf.float32)
        output = ops.replicated_spmm(lhs, topology, rhs)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            result = sess.run(output)
            pairs = zip(lhs_dense, rhs_dense)
            for replica, (lhs_i, rhs_i) in enumerate(pairs):
                self.assertAllClose(result[replica, :],
                                    lhs_i.dot(rhs_i),
                                    atol=1e-03,
                                    rtol=1e-05)
    def testSparseSoftmax(self, m, n, sparsity):
        """Checks ops.sparse_softmax against a row-wise dense numpy softmax."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        dense = make_mask(sample([m, n]))

        sparse_input = sparse_matrix.SparseMatrix("input", matrix=dense)
        output = ops.sparse_softmax(sparse_input)

        def row_softmax(x):
            # Subtract each row's max before exponentiating for stability.
            e = np.exp(x - x.max(axis=1, keepdims=True))
            return e / e.sum(axis=1, keepdims=True)

        with self.test_session(use_gpu=True) as sess:
            sess.run(tf.global_variables_initializer())

            # Give masked-out entries a huge negative logit so they add
            # effectively nothing to each row's normalizer.
            dense[dense == 0] = -1e9
            reference = self.dense_to_scipy(row_softmax(dense))

            fetched = sess.run(
                [output.values, output.row_offsets, output.column_indices])
            computed = self.sparse_to_scipy(*fetched, reference.shape)

            self.assert_sparse_matrix_equal(computed,
                                            reference,
                                            atol=1e-03,
                                            rtol=1e-05)
Example #4
0
    def __init__(self,
                 name,
                 shape=None,
                 matrix=None,
                 initializer=initializers.Uniform(),
                 connector=connectors.Uniform(0.8),
                 trainable=True,
                 dtype=tf.float32):
        """Builds a CSR matrix whose components are TensorFlow variables.

        Exactly one of `shape` or `matrix` must be supplied. With `shape`, a
        dense matrix is drawn with `initializer` and sparsified by
        `connector`. With `matrix`, that dense 2-D array is converted directly.

        Args:
          name: prefix for the names of every variable created here.
          shape: optional 2-element [rows, columns]; used when `matrix` is None.
          matrix: optional dense 2-D array to convert.
          initializer: called with `shape` to draw the dense values.
          connector: applied to the drawn dense values to zero out entries.
          trainable: whether the nonzero-values variable is trainable.
          dtype: dtype of the nonzero-values variable.

        NOTE(review): the `initializer` and `connector` defaults are evaluated
        once, when the function is defined. Every matrix built with the
        defaults therefore shares the same two helper instances.
        """
        if matrix is not None:
            assert shape is None
            assert len(matrix.shape) == 2
            self._shape = matrix.shape
        else:
            assert shape is not None and len(shape) == 2
            matrix = connector(initializer(shape))
            self._shape = shape
        self._name = name
        self._trainable = trainable
        self._dtype = dtype
        self._sparsity = 1.0 - np.count_nonzero(matrix) / matrix.size

        # Host-side numpy CSR components of the dense matrix.
        values_, row_indices_, row_offsets_, column_indices_ = (
            _dense_to_sparse(matrix))

        # Shape scalars pinned to the host. They are for internal use; end
        # users should rely on the python-level 'shape' property instead.
        with tf.device("cpu"):
            self._rows, self._columns = [
                tf.get_variable(initializer=extent,
                                trainable=False,
                                name=self._name + suffix,
                                dtype=tf.int32)
                for extent, suffix in ((self._shape[0], "_rows"),
                                       (self._shape[1], "_columns"))
            ]

        # Nonzero values: the only component that may be trained.
        self._values = tf.get_variable(initializer=values_,
                                       trainable=self.trainable,
                                       name=self._name + "_values",
                                       dtype=self._dtype)

        def index_variable(init, suffix):
            # Index arrays are never trainable and are stored as uint32.
            return tf.get_variable(initializer=init,
                                   trainable=False,
                                   name=self._name + suffix,
                                   dtype=tf.uint32)

        self._row_indices = index_variable(row_indices_, "_row_indices")
        self._row_offsets = index_variable(row_offsets_, "_row_offsets")
        self._column_indices = index_variable(column_indices_,
                                              "_column_indices")

        # Register this matrix in the collection of trainable sparse matrices.
        track_trainable_sparse_matrix(self)
Example #5
0
    def build(self, input_shape):
        """Creates the sparse kernel and, if enabled, the bias variable.

        The kernel is a [filters, input_channels] SparseMatrix. Its sparsity is
        derived from `self.nonzeros` so that the kernel holds that many values.
        """
        input_channels = input_shape.as_list()[1]
        with tf.variable_scope(self.name, default_name="sparse_conv2d"):
            # TODO(tgale): This is a hack to make sure the sparsities
            # match exactly, not a general solution.
            dense_size = self.filters * input_channels
            kernel_sparsity = 1.0 - self.nonzeros / dense_size
            self.kernel = sparse_matrix.SparseMatrix(
                "kernel", [self.filters, input_channels],
                connector=connectors.Uniform(kernel_sparsity))

            if self.use_bias:
                self.bias = tf.get_variable("bias", [self.filters])
    def testSpmmGradient(self, m, k, n, sparsity, use_gpu):
        """Numerically checks the gradients of ops.spmm w.r.t. both operands."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        lhs_dense = make_mask(sample([m, k]))
        rhs_dense = sample([k, n])

        lhs = sparse_matrix.SparseMatrix("lhs", matrix=lhs_dense)
        rhs = tf.Variable(rhs_dense, dtype=tf.float32)
        product = ops.spmm(lhs, rhs)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            xs = [lhs.values, rhs]
            x_shapes = [lhs.values.shape.as_list(), [k, n]]
            max_error = tf.test.compute_gradient_error(
                xs, x_shapes, product, [m, n])
            self.assertLess(max_error, 1e-3)
Example #7
0
  def testCreateMatrix(self, m, n, sparsity):
    """Checks that a generated SparseMatrix has rank-1 parts of the right size."""
    matrix = sparse_matrix.SparseMatrix(
        "matrix", [m, n], connector=connectors.Uniform(sparsity))
    fetches = [
        matrix.values, matrix.row_indices, matrix.row_offsets,
        matrix.column_indices
    ]

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      results = sess.run(fetches)

      # Every CSR component should be a 1-D array.
      for component in results:
        self.assertLen(component.shape, 1)

      # The number of stored values must match the requested sparsity.
      expected_nnz = m * n - int(round(sparsity * m * n))
      self.assertEqual(results[0].shape[0], expected_nnz)
    def testSpmm(self, m, k, n, sparsity, use_gpu):
        """Checks ops.spmm (sparse lhs times dense rhs) against numpy."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        lhs_dense = make_mask(sample([m, k]))
        rhs_dense = sample([k, n])

        lhs = sparse_matrix.SparseMatrix("lhs", matrix=lhs_dense)
        rhs = tf.Variable(rhs_dense, dtype=tf.float32)
        product = ops.spmm(lhs, rhs)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            computed = sess.run(product)
            reference = lhs_dense.dot(rhs_dense)
            self.assertAllClose(computed, reference, atol=1e-03, rtol=1e-05)
Example #9
0
    def testSddmm_Replicated(self, r, m, k, n, sparsity, force_gpu):
        """Checks replicated_sddmm: r products sharing one output topology."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        lhs_dense = sample([r, m, k])
        rhs_dense = sample([r, n, k])
        pattern = make_mask(np.ones([m, n]))

        topology = sparse_matrix.SparseTopology("output_topology",
                                                mask=pattern)
        lhs = tf.Variable(lhs_dense, dtype=tf.float32)
        rhs = tf.Variable(rhs_dense, dtype=tf.float32)
        output = ops.replicated_sddmm(lhs, rhs, topology, transpose_rhs=True)

        with self.test_session(force_gpu=force_gpu) as sess:
            sess.run(tf.global_variables_initializer())

            # One values row per replica; the index arrays are shared.
            values, offsets, columns = sess.run(
                [output, topology.row_offsets, topology.column_indices])

            for replica in range(r):
                a = lhs_dense[replica]
                b = rhs_dense[replica]
                reference = self.dense_to_scipy(pattern * a.dot(b.T))
                computed = self.sparse_to_scipy(values[replica, :],
                                                offsets,
                                                columns,
                                                shape=reference.shape)
                self.assert_sparse_matrix_equal(computed,
                                                reference,
                                                atol=1e-03,
                                                rtol=1e-05)
Example #10
0
    def testSparseSoftmax_Replicated(self, r, m, n, sparsity):
        """Checks replicated_sparse_softmax row-wise on every replica."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        mask = make_mask(np.ones([m, n]))
        dense = mask[np.newaxis, :, :] * sample([r, m, n])

        topology = sparse_matrix.SparseTopology("topology", mask=mask)
        # Nonzeros gathered in row-major order, one row per replica.
        packed = dense[dense != 0].reshape(r, -1)
        values = tf.Variable(packed, dtype=tf.float32)
        output = ops.replicated_sparse_softmax(values, topology)

        def row_softmax(x):
            # Subtract each row's max before exponentiating for stability.
            e = np.exp(x - x.max(axis=1, keepdims=True))
            return e / e.sum(axis=1, keepdims=True)

        with self.test_session(use_gpu=True) as sess:
            sess.run(tf.global_variables_initializer())
            result, offsets, columns = sess.run(
                [output, topology.row_offsets, topology.column_indices])

            # Give masked-out entries a huge negative logit so they add
            # effectively nothing to each row's normalizer.
            dense[dense == 0] = -1e9

            for replica in range(r):
                reference = self.dense_to_scipy(row_softmax(dense[replica]))
                computed = self.sparse_to_scipy(result[replica, :], offsets,
                                                columns, reference.shape)
                self.assert_sparse_matrix_equal(computed,
                                                reference,
                                                atol=1e-03,
                                                rtol=1e-05)
Example #11
0
    def testCsr2Idx(self, m, n, sparsity, use_gpu):
        """Checks ops.csr2idx yields row-major linear indices of the nonzeros."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()
        dense = make_mask(sample([m, n]))

        sparse_input = sparse_matrix.SparseMatrix("input", matrix=dense)
        output = ops.csr2idx(sparse_input)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())

            # Linear index of entry (row, col) is row * n + col.
            csr = self.dense_to_scipy(dense)
            per_row = []
            for row in range(m):
                start, end = csr.indptr[row], csr.indptr[row + 1]
                per_row.append(csr.indices[start:end] + row * n)
            expected = np.concatenate(per_row)
            self.assertAllEqual(sess.run(output), expected)
Example #12
0
    def testTranspose(self, m, n, sparsity, use_gpu):
        """Checks ops.transpose against numpy's transpose of the dense matrix."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()
        dense = make_mask(sample([m, n]))

        sparse_input = sparse_matrix.SparseMatrix("input", matrix=dense)
        transposed = ops.transpose(sparse_input)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            reference = self.dense_to_scipy(dense.T)
            fetched = sess.run([
                transposed.values, transposed.row_offsets,
                transposed.column_indices
            ])
            computed = self.sparse_to_scipy(*fetched, shape=reference.shape)
            self.assert_sparse_matrix_equal(computed,
                                            reference,
                                            atol=1e-03,
                                            rtol=1e-05)
Example #13
0
    def testSddmmGradient(self, m, k, n, sparsity, use_gpu):
        """Numerically checks the gradients of ops.sddmm w.r.t. both operands."""
        make_mask = connectors.Uniform(sparsity)
        sample = initializers.Uniform()

        lhs_dense = sample([m, k])
        rhs_dense = sample([n, k])
        pattern = make_mask(np.ones([m, n]))

        topology = sparse_matrix.SparseMatrix("output", matrix=pattern)
        lhs = tf.Variable(lhs_dense, dtype=tf.float32)
        rhs = tf.Variable(rhs_dense, dtype=tf.float32)
        output = ops.sddmm(lhs, rhs, topology, transpose_rhs=True)

        with self.test_session(use_gpu=use_gpu) as sess:
            sess.run(tf.global_variables_initializer())
            y = output.values
            max_error = tf.test.compute_gradient_error(
                [lhs, rhs], [[m, k], [n, k]], y, y.shape.as_list())
            self.assertLess(max_error, 1e-3)
Example #14
0
  def testDenseToSparse(self, m, n, sparsity):
    """Checks _dense_to_sparse against scipy's CSR conversion of the input."""
    # Helpers to set up the matrices.
    connector = connectors.Uniform(sparsity)
    initializer = initializers.Uniform()

    # Create a dense matrix in numpy with the specified sparsity.
    matrix = connector(initializer([m, n]))

    # Convert to a sparse numpy matrix with the function under test.
    values, row_indices, row_offsets, column_indices = (
        sparse_matrix._dense_to_sparse(matrix))

    # Build the reference CSR directly from the dense input, independently of
    # the function under test. A reference built from its own outputs would
    # make the comparisons below hold trivially.
    expected_output = scipy.sparse.csr_matrix(matrix)

    # Expected row order: rows sorted by decreasing number of nonzeros.
    expected_row_indices = np.argsort(-1 * np.diff(expected_output.indptr))

    # Compare the matrices. Values are compared with a tolerance in case the
    # converter stores them at a different float precision than the input.
    self.assertAllClose(expected_output.data, values)
    self.assertAllEqual(expected_output.indptr, row_offsets)
    self.assertAllEqual(expected_output.indices, column_indices)
    self.assertAllEqual(expected_row_indices, row_indices)
Example #15
0
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        decode_loop_step=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        layer_collection=None,
                        recurrent_memory_by_layer=None,
                        chunk_number=None):
    """A stack of transformer layers.

    Before the layers are built, self-attention is configured according to
    `hparams.sparse_attention_mode`:
      * "sparse": `hparams.self_attention_type` is wrapped with a random
        SparseTopology of shape [seqlen, seqlen].
      * "masked": `decoder_self_attention_bias` is replaced by the output of
        `generate_sparse_attention_mask`.
      * "dense": `hparams.self_attention_type` is wrapped so that it receives
        `decoder_self_attention_bias` as `bias`.
      * None: the standard attention is used unchanged.

    NOTE: the "sparse" and "dense" modes modify `hparams` in place.

    Args:
      decoder_input: a Tensor
      encoder_output: a Tensor
      decoder_self_attention_bias: bias Tensor for self-attention (see
        common_attention.attention_bias())
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
        (see common_attention.attention_bias())
      hparams: hyperparameters for model
      cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
      decode_loop_step: An integer, step number of the decoding loop. Only used
        for inference on TPU.
      name: a string
      nonpadding: optional Tensor with shape [batch_size, encoder_length]
        indicating what positions are not padding.  This is used to mask out
        padding in convolutional layers.  We generally only need this mask for
        "packed" datasets, because for ordinary datasets, no padding is ever
        followed by nonpadding.
      save_weights_to: an optional dictionary to capture attention weights for
        visualization; the weights tensor will be appended there under a string
        key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.
      losses: optional list onto which to append extra training losses
      layer_collection: A tensorflow_kfac.LayerCollection. Only used by the KFAC
        optimizer. Default is None.
      recurrent_memory_by_layer: Optional dict, mapping layer names to instances
        of transformer_memory.RecurrentMemory. Default is None.
      chunk_number: an optional integer Tensor with shape [batch] used to
        operate the recurrent_memory.

    Returns:
      y: a Tensor

    Raises:
      ValueError: if `hparams.sparse_attention_mode` is not "sparse",
        "masked", "dense" or None.
    """
    x = decoder_input
    mode = hparams.sparse_attention_mode

    if mode == "sparse":
        # If we want to run with our actual sparse kernels, intercept
        # the self_attention_type and replace it with our attention fn.
        # NOTE(review): this rewrites hparams in place; calling this function
        # again with the same hparams wraps the attention fn a second time.
        seqlen = common_layers.shape_list(x)[1]
        sparse_attention_topology = sparse_matrix.SparseTopology(
            "sparse_attention", [seqlen, seqlen],
            # NOTE(review): the trailing value differs from the sparsity
            # actually passed here; confirm which one is intended.
            connector=connectors.Uniform(0.955411645))  # 0.955411659
        hparams.self_attention_type = functools.partial(
            hparams.self_attention_type, topology=sparse_attention_topology)
    elif mode == "masked":
        # If we're training with sparse attention, create the per-layer
        # attention bias that describes the sparsity pattern.
        #
        # NOTE: We share the same pattern across all attention heads
        # within a layer due to memory constraints (because we're not
        # actually training with sparse kernels). Per-head patterns
        # would likely perform better.
        #
        # NOTE: We also share the same pattern across all layers, as
        # protobuf can't save all of these large tensors if we create
        # more than one of them.
        decoder_self_attention_bias = generate_sparse_attention_mask(
            common_layers.shape_list(x)[1], hparams, 0)
        tf.logging.info("Generated sparse attention mask.")
    elif mode == "dense":
        # Replace the dot-product attention with our memory efficient
        # version.
        hparams.self_attention_type = functools.partial(
            hparams.self_attention_type, bias=decoder_self_attention_bias)
    elif mode is not None:
        # Explicit check rather than `assert`, so an unknown mode is still
        # rejected when Python runs with -O.
        raise ValueError(
            "Unsupported hparams.sparse_attention_mode: %r" % (mode,))
    # Otherwise (mode is None), e.g. for training on TPU, use T2T's standard
    # attention.

    with tf.variable_scope(name):
        for layer_idx in range(hparams.num_decoder_layers
                               or hparams.num_hidden_layers):
            x = transformer.transformer_decoder_layer(
                x,
                decoder_self_attention_bias,
                layer_idx,
                hparams,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                encoder_output=encoder_output,
                cache=cache,
                decode_loop_step=decode_loop_step,
                nonpadding=nonpadding,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                losses=losses,
                layer_collection=layer_collection,
                recurrent_memory_by_layer=recurrent_memory_by_layer,
                chunk_number=chunk_number)

        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.
        return common_layers.layer_preprocess(
            x, hparams, layer_collection=layer_collection)