Example #1
  def testReturnsTopoSort(self, scheduler_alg):
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    x = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('a:3,b:4'),
                     dtype=tf.int32,
                     name='X').outputs[0]
    y = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('b:4,c:5'),
                     dtype=tf.int32,
                     name='Y').outputs[0]
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z1')
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z2')

    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('Z1:0')
    graph.set_tensor_final('Z2:0')
    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

    self.assertCountEqual(schedule[0:2], [0, 1])
    self.assertCountEqual(schedule[2:4], [2, 3])
Example #2
  def testMinimizePeakMemoryList_SingleUseTensor(self):
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:4'), dtype=tf.int32,
                 name='X')
    y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:3'), dtype=tf.int32,
                     name='Y').outputs[0]
    mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')

    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('X:0')
    graph.set_tensor_final('Z:0')
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))
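    # The LIST algorithm greedily schedules, at each step, the operation
    # that frees the most memory (here all values are negative):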
    # When nothing is scheduled:
    #   X frees -4 entries
    #   Y frees -3 entries
    # After [Y] scheduled:
    #   X frees -4 entries
    #   Z frees -3 entries
    # Hence the schedule should be [Y, Z, X].
    self.assertEqual(schedule, [1, 2, 0])
Example #3
  def setUp(self):
    super(OperationSplittabilityTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, "my_mesh")

    self.a_dim = mtf.Dimension("a", 5)
    self.b_dim = mtf.Dimension("b", 10)
    self.c_dim = mtf.Dimension("c", 15)

    self.ab_shape = mtf.Shape([self.a_dim, self.b_dim])
    self.x = mtf.zeros(self.mesh, self.ab_shape)

    self.batch_dim = mtf.Dimension("batch", 100)
    self.grid_h_dim = mtf.Dimension("grid_h", 10)
    self.grid_w_dim = mtf.Dimension("grid_w", 10)
    self.filter_h_dim = mtf.Dimension("filter_h", 5)
    self.filter_w_dim = mtf.Dimension("filter_w", 5)
    self.in_dim = mtf.Dimension("in", 10)
    self.out_dim = mtf.Dimension("out", 10)
    self.image = mtf.zeros(self.mesh, [self.batch_dim, self.grid_h_dim,
                                       self.grid_w_dim, self.in_dim])
Example #4
    def setUp(self):
        super(LayoutValidatorTest, self).setUp()
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        a_dim = mtf.Dimension("a", 5)
        b_dim = mtf.Dimension("b", 10)
        concat_dim1 = mtf.Dimension("concat", 15)
        concat_dim2 = mtf.Dimension("concat", 20)

        x1 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim1]))
        x2 = mtf.zeros(mesh, mtf.Shape([a_dim, b_dim, concat_dim2]))
        mtf.ConcatOperation([x1, x2], "concat")

        # We add a tensor with anonymous shape, which is supposed to be
        # unsplittable (i.e. none of its dimensions show up during
        # test_SplittableMtfDimensionNames).
        _ = mtf.zeros(mesh, mtf.anonymous_shape(mtf.Shape([a_dim, b_dim])))

        mesh_shape = mtf.Shape([("m1", 4), ("m2", 2)])
        self.valid_layouts = valid_layouts.LayoutValidator(graph, mesh_shape)
Example #5
def WrongShape(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:-2] +
                         [('axis0', shape[2]), ('axis0', shape[3])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), 'axis0', RandName(),
         RandName()])

    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # Run() should have raised ValueError
    except ValueError:
        return
Example #6
    def test_variable_placer(self):
        sizes = [100, 0, 0, 0]
        device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']

        with tf.Graph().as_default() as g:
            var_placer = mtf.utils.BalancedVariablePlacer(device_list, sizes)
            graph = mtf.Graph()
            mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

            hidden_dim = mtf.Dimension('hidden', 10)
            output_dim = mtf.Dimension('output_feature', 10)

            for i in xrange(5):
                # Each variable takes 400 Bytes, and will be placed from cpu:1.
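                # (Each variable is 10 x 10 float32 values = 400 bytes; cpu:0
                # already holds 100 bytes from sizes, so cpu:1 is least loaded.)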
                mtf.get_variable(mesh, 'w{}'.format(i),
                                 [hidden_dim, output_dim])

            for i in xrange(5):
                var = g.get_tensor_by_name('w{}:0'.format(i))
                device = (i + 1) % len(device_list)
                self.assertEqual('cpu:{}'.format(device), var.device)
Example #7
  def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
                                 heads, block_length):
    length_q = length
    length_m = length
    query = tf.random_normal([batch, length_q, io_channels])
    memory = tf.random_normal([batch, length_m, io_channels])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    length_q_dim = mtf.Dimension("length_q", length_q)
    length_m_dim = mtf.Dimension("length_m", length_m)
    io_channels_dim = mtf.Dimension("io_channels", io_channels)
    kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
    heads_dim = mtf.Dimension("heads", heads)

    mtf_query = mtf.import_tf_tensor(
        mesh, query,
        shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
    mtf_memory = mtf.import_tf_tensor(
        mesh, memory,
        shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
    mtf_outputs = mtf.layers.masked_local_attention_1d(
        mtf_query,
        mtf_memory,
        kv_channels=kv_channels_dim,
        heads=heads_dim,
        block_length=block_length)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual = self.evaluate(actual_outputs)

    self.assertEqual(actual.shape, (batch, length_q, io_channels))
Example #8
def main(_):

  # Creating layout and mesh implementation
  mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
  layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                  ("ny", "col"), ("ty", "row"), ("tz", "col"), ("ty_lr", "row"),
                  ("tz_lr", "col"), ("nx_block", "row"), ("ny_block", "col")]
  mesh_impl = HvdSimdMeshImpl(
      mtf.convert_to_shape(mesh_shape),
      mtf.convert_to_layout_rules(layout_rules))

  # Create the graph and mesh
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  # Load initial power spectrum
  ps_data = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
  klin, plin = ps_data[0], ps_data[1]

  # Defines the computational graph for the nbody
  initial_conditions, final_field = nbody_fn(mesh, klin, plin)

  # Lower mesh computation
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve fields as tf tensors
  tf_initc = lowering.export_to_tf_tensor(initial_conditions)
  tf_final = lowering.export_to_tf_tensor(final_field)

  with tf.Session() as sess:
    start = time.time()
    init_conds, final = sess.run([tf_initc, tf_final])
    end = time.time()
    print('\nTime for the mesh run: %f\n' % (end - start))

  # Export these fields
  np.save('simulation_output_%d.npy' % comm.Get_rank(), final)
  np.save('simulation_input_%d.npy' % comm.Get_rank(), init_conds)

  exit(0)
Example #9
  def testWeightsNonzero(self):
    inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
    channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

    expected_outputs = common_layers.weights_nonzero(inputs)
    tf_group = lowering.copy_masters_to_slices()
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])

    self.assertAllEqual(actual, expected)
Example #10
    def testNthSmallestReduceSecondDim(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        a_dim = mtf.Dimension("a", 6)
        b_dim = mtf.Dimension("b", 2)
        inputs = tf.constant([[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]])
        n = 0  # find smallest element (n is zero-indexed)
        reduced_dim = b_dim
        expected_outputs = tf.constant([1, 2, 3, 4, 5, 5])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape([a_dim, b_dim]))
        mtf_outputs = mtf.nth_smallest_element(mtf_inputs, n, reduced_dim,
                                               "test_nth_smallest")
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape="all:2",
                                                              layout="a:all",
                                                              devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
        self.assertAllEqual(self.evaluate(actual_outputs),
                            self.evaluate(expected_outputs))
Example #11
File: demo.py  Project: EiffL/mesh_demo
def model_fn(nc=64, batch_size=1):
    """
    Example of function implementing a CNN and returning a value.
    """

    # Create the mesh TF graph
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    n_block_x = 4
    n_block_y = 2
    n_block_z = 1

    batch_dim = mtf.Dimension("batch", batch_size`)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc//n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc//n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc//n_block_z)

    image_c_dim = mtf.Dimension('image_c', 3)
    hidden_dim  = mtf.Dimension('h', 128)

    # Create some input data
    data = mtf.random_uniform(mesh, [batch_dim, nx_dim, ny_dim, nz_dim,
                                     sx_dim, sy_dim, sz_dim,
                                     image_c_dim])

    net = mtf.layers.conv3d_with_blocks(data, hidden_dim,
                                        filter_size=(3, 3, 3),
                                        strides=(1, 1, 1),
                                        padding='SAME',
                                        d_blocks_dim=nx_dim,
                                        h_blocks_dim=ny_dim)

    net = mtf.reduce_sum(net, output_shape=[batch_dim, hidden_dim])

    return net
Example #12
  def testBatchNorm(self):
    batch = 2
    channels = 3
    inputs = tf.constant([[0, 1, 2], [4, 5, 6]], dtype=np.float32)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    channels_dim = mtf.Dimension("channels", channels)

    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))

    mtf_outputs_0, _ = mtf.layers.batch_norm(
        mtf_inputs,
        is_training=True, momentum=0.95, epsilon=1e-6,
        dims_idx_start=0, dims_idx_end=1, name="bn0")
    mtf_outputs_1, _ = mtf.layers.batch_norm(
        mtf_outputs_0 * 2 + 1,
        is_training=True, momentum=0.95, epsilon=1e-6,
        dims_idx_start=0, dims_idx_end=1, name="bn1")

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    actual_outputs_0 = lowering.export_to_tf_tensor(mtf_outputs_0)
    actual_outputs_1 = lowering.export_to_tf_tensor(mtf_outputs_1)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    [actual_0, actual_1] = self.evaluate([actual_outputs_0, actual_outputs_1])

    expected = np.array([[-1, -1, -1], [1, 1, 1]])
    self.assertAllClose(actual_0, expected)
    self.assertAllClose(actual_1, expected)
Example #13
    def testMultiheadAttention(self, kv_channels, heads):
        batch = 2
        length = 8
        channels = 3
        query = tf.random_normal([batch, length, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)
        channels_dim = mtf.Dimension("channels", channels)
        kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
        heads_dim = mtf.Dimension("heads", heads)

        mtf_query = mtf.import_tf_tensor(
            mesh,
            query,
            shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
        mtf_outputs = mtf.layers.multihead_attention(
            mtf_query,
            memory_antecedent=None,
            mask=None,
            kv_channels=kv_channels_dim,
            heads=heads_dim,
            is_training=False)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual = self.evaluate(actual_outputs)

        self.assertEqual(actual.shape, query.shape)
Example #14
    def testLayoutAndMeshShape(self):
        # Same as previous test, but don't specify a 4x2 mesh.
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, "my_mesh")
        x = mtf.zeros(mesh, "a:10,b:5")
        y = mtf.zeros(mesh, "b:5,c:20")
        z = mtf.einsum([x, y], "a:10,c:20")

        layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
            mtf_graph, 8, [z])

        a_dim = mtf.convert_to_dimension(("a", 10))
        b_dim = mtf.convert_to_dimension(("b", 5))
        c_dim = mtf.convert_to_dimension(("c", 20))

        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)

        self.assertCountEqual(
            mesh_shape.dims,
            [mtf.Dimension("mesh_0", 4),
             mtf.Dimension("mesh_1", 2)])

        layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
            mtf_graph, 8, [z], 1)

        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape))
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape))

        self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
Example #15
def test_model():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    seq_len = params["n_ctx"]

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)

    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    }    

    # create mask

    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)
Example #16
    def testDense(self, units, use_bias, new_dim_name):
        batch = 2
        channels = 3
        inputs = tf.random_normal([batch, channels])

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        channels_dim = mtf.Dimension("channels", channels)
        new_dim = mtf.Dimension(new_dim_name, units)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.dense(
            mtf_inputs,
            new_dims=new_dim,
            reduced_dims=[channels_dim],
            activation=mtf.relu,
            use_bias=use_bias,
        )
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

        expected_outputs = tf.keras.layers.Dense(units=units,
                                                 activation=tf.nn.relu,
                                                 use_bias=use_bias)(inputs)
        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertEqual(actual.shape, expected.shape)
Example #17
    def testDynamicText2self(self):
        batch = 2
        length = 5
        input_tensors = {
            "inputs": [[3, 1, 4, 1, 0], [1, 4, 3, 2, 1]],
            "inputs_segmentation": [[1, 1, 2, 2, 0], [1, 2, 2, 2, 2]],
            "inputs_position": [[0, 1, 0, 1, 0], [0, 0, 1, 2, 3]],
            "targets": [[1, 1, 0, 0, 0], [9, 8, 1, 2, 1]],
            "targets_segmentation": [[1, 2, 0, 0, 0], [1, 1, 1, 2, 2]],
            "targets_position": [[0, 0, 0, 0, 0], [0, 1, 2, 0, 1]]
        }
        expected_output_tensors = {
            "targets": [[3, 1, 1, 4, 1, 1, 0, 0, 0, 0],
                        [1, 9, 8, 1, 4, 3, 2, 1, 2, 1]],
            "targets_segmentation": [[1, 1, 1, 2, 2, 2, 0, 0, 0, 0],
                                     [1, 1, 1, 1, 2, 2, 2, 2, 2, 2]],
            "targets_position": [[0, 1, 2, 0, 1, 2, 0, 0, 0, 0],
                                 [0, 1, 2, 3, 0, 1, 2, 3, 4, 5]]
        }
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        length_dim = mtf.Dimension("length", length)

        input_shape = mtf.Shape([batch_dim, length_dim])
        mtf_features = {
            k: mtf.import_tf_tensor(mesh, v, input_shape)
            for k, v in input_tensors.items()
        }
        mtf_outputs = utils._dynamic_text2self(mtf_features)
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        for k, v in expected_output_tensors.items():
            out = lowering.export_to_tf_tensor(mtf_outputs[k])
            actual = self.evaluate(out)
            self.assertAllEqual(actual, v)
Example #18
def test_sampling():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # create mask

    seq_len = params["n_ctx"]
    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    params["mode"] = "predict"

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs, other_features=other_features, params=params, variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)
Example #19
  def testLayout(self):
    # Construct a Mesh TensorFlow graph and mesh.
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, "my_mesh")
    x = mtf.zeros(mesh, "a:10,b:5")
    y = mtf.zeros(mesh, "b:5,c:20")
    z = mtf.einsum([x, y], "a:10,c:20")

    # Decide on a mesh shape.
    mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

    # Compute a layout based on the graph and mesh.
    # Note that knowing the identity of the outputs is important to the
    # optimization since they cannot be freed.
    layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])

    a_dim = mtf.convert_to_dimension(("a", 10))
    b_dim = mtf.convert_to_dimension(("b", 5))
    c_dim = mtf.convert_to_dimension(("c", 20))

    self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
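
A minimal follow-on sketch (not part of the original test) of how a layout
computed this way could drive a lowering; the single-process placement mesh
and the list of eight empty device strings are assumptions:

    # Hypothetical continuation of testLayout above.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout, [""] * mesh_shape.size)  # m1:4,m2:2 -> 8 devices
    lowering = mtf.Lowering(mtf_graph, {mesh: mesh_impl})
    tf_z = lowering.export_to_tf_tensor(z)  # z computed under the chosen layout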
Example #20
  def testGraph(self):
    graph = mtf.Graph()
    self.assertEmpty(graph.operations)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)
    mesh = mtf.Mesh(graph, "mesh_test")
    _ = mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    self.assertLen(graph.operations, 1)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)
    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    self.assertLen(graph.operations, 2)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 1)
    _ = mtf.get_variable(mesh,
                         "variable_1",
                         mtf.Shape([]),
                         trainable=False)
    self.assertLen(graph.operations, 3)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 2)
Example #21
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    mesh = mtf.Mesh(graph, 'mesh0')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl([num_gpus//2],
            devices=devices[:num_gpus//2], gpus_per_node=gpus_per_node)

    mesh = mtf.Mesh(graph, 'mesh1')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl([num_gpus//2],
            devices=devices[num_gpus//2:], gpus_per_node=gpus_per_node)

    mesh = mtf.Mesh(graph, 'mesh2')
    meshes.append(mesh)
    mesh_to_impl[mesh] = utils.GetMeshImpl([num_gpus],
            devices=utils.FlattenList(utils.TransposeLists(
                [devices[:num_gpus//2], devices[num_gpus//2:]])),
            gpus_per_node=gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    shape = utils.ConvertToShape([('axis0', batch_size),
        inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(meshes[2], inputs, shape)
    shape = shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(meshes[2], labels, shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
Example #22
    def testSeparableConv1d(self, random_normal_initializer_mock):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        depth_dim = mtf.Dimension("depth", 2)
        length_dim = mtf.Dimension("length", 4)
        output_dim = mtf.Dimension("output", 2)

        x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
        mtf_x = mtf.import_tf_tensor(mesh,
                                     x,
                                     shape=mtf.Shape([length_dim, depth_dim]))

        initializer_mock = mock.MagicMock()
        random_normal_initializer_mock.return_value = initializer_mock
        initializer_mock.side_effect = initialize_by_shape({
            (2, ): [1, 2],
            (2, 2): [[1, 0], [1, -1]],
        })

        mtf_output = mtf.layers.separable_conv1d(mtf_x,
                                                 output_dim,
                                                 min_relative_pos=-1,
                                                 max_relative_pos=1,
                                                 use_bias=True)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_output = lowering.export_to_tf_tensor(mtf_output)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate(actual_output)

        self.assertAllClose(actual, [[3, -2], [6, -4], [9, -6], [7, -4]])
Example #23
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s' % device_list)

  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    field = nbody_model(mesh)
    batch_dim, x_dim, y_dim, z_dim = field.shape
    x_dim_nosplit = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
    y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

    # Until we implement distributed outputs, we only return one example
    field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size-1])
    field_slice = mtf.reshape(field_slice, [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
    #field_slice = field

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(mode, predictions={'field': tf_field})
Example #24
    def testConv1d(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        filter_size = 3
        depth_dim = mtf.Dimension("depth", 2)
        length_dim = mtf.Dimension("length", 4)
        output_dim = mtf.Dimension("output", 2)

        x = tf.constant([[1, 0], [0, 1], [1, 1], [2, 1]], dtype=tf.float32)
        mtf_x = mtf.import_tf_tensor(mesh,
                                     x,
                                     shape=mtf.Shape([length_dim, depth_dim]))

        initializer_mock = mock.MagicMock()
        initializer_mock.side_effect = initialize_by_shape({
            (1, 3, 2, 2): [[[[1, -1], [0, 0]], [[2, -2], [-1, 1]],
                            [[3, -3], [-2, 2]]]],
        })

        mtf_output = mtf.layers.conv1d(mtf_x,
                                       output_dim=output_dim,
                                       filter_size=filter_size,
                                       filter_initializer=initializer_mock)

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_output = lowering.export_to_tf_tensor(mtf_output)

        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering.copy_masters_to_slices())
        actual = self.evaluate(actual_output)

        self.assertAllClose(actual, [[0, 0], [1, -1], [5, -5], [4, -4]])
Example #25
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
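        # For example, with 8 replicas this reserves 8 * 300MB = 2.4GB on
        # worker 0 and nothing on the remaining hosts.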
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(ctx.device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   layout=layout_rules,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_logits,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_logits,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_logits = tf.reshape(masked_lm_logits,
                                              [-1, masked_lm_logits.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_logits = tf.reshape(
                    next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(masked_lm_example_loss),
                lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
                masked_lm_weights,
                lowering.export_to_tf_tensor(next_sentence_example_loss),
                lowering.export_to_tf_tensor(next_sentence_logits),
                next_sentence_labels
            ])

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
Example #26
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None):
        hparams = copy.deepcopy(hparams)
        use_tpu = params and params.get("use_tpu", False)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
Example #27
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    # wrapped graph named "my_mesh"
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    print("mesh_shape.size = ", mesh_shape.size)
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
Example #28
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
Example #29
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if isinstance(features_dict[key], dict):
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)
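
    # import_fully_replicated assumes the tf input is identical on every
    # replica; how each mtf dimension is then split across the mesh is
    # determined by layout_rules at lowering time.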

    # Instantiate dict for dimensions, bias, etc. that can be calculated here once and then passed into the model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None
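    # For causal models, the bias mask presumably adds a large negative
    # value to attention logits at future positions so that each token can
    # only attend to itself and the past.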

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need a separate sequence dimension for the position-embedding
    # weights: gathering breaks when both arguments share a dimension,
    # raising "Einsum has lhs dimension without corresponding rhs or output
    # dimension."
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

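        # scaffold_fn customizes session initialization: the local init op
        # copies mtf master variables to their slices, and the ready op
        # reports any uninitialized variables/resources before the session
        # is considered ready.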
        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we'd better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Get the number of microbatches per batch for serialized training.
        # If params["tokens_per_mb_per_replica"] is None, this defaults to 1
        # and no microbatching is performed.
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num_microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise ValueError(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise ValueError(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering, or they will not be included
        # in the graph.
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.cast(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch),
            tf.float32)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint saver and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

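            # Converts the per-token loss (in nats) to bits per byte:
            # multiply by the tokens-per-byte ratio of the tokenizer
            # (0.29335 here, presumably measured on the eval corpus) and
            # divide by ln(2) to convert nats to bits.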
            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # tf_loss_batch may include z_loss and other auxiliary terms,
                # so this should perhaps be computed separately in the future.
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
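
# Usage sketch (illustrative, not part of the original example): wiring the
# model_fn above into a TPUEstimator via the TF1 API. `tpu_estimator` is the
# module already used in the example's return statements; the batch sizes
# and cluster value are assumptions.
from tensorflow.python.tpu import tpu_config


def build_estimator(params, batch_size, tpu_name=None):
    # Resolve the TPU cluster if a name was given; otherwise run locally.
    cluster = (tf.distribute.cluster_resolver.TPUClusterResolver(tpu_name)
               if tpu_name else None)
    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # saving handled by CheckpointSaverHook
        tpu_config=tpu_config.TPUConfig(
            iterations_per_loop=params["steps_per_checkpoint"],
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
    return tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=run_config,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size,
        params=params)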
Example #30
    def testPool(self, pooling_method):
        batch = 2
        depth = 3
        height = 4
        width = 6
        channels = 3
        tf.random.set_random_seed(1234)
        inputs = tf.random_normal([batch, depth, height, width, channels])

        stride_d = 3
        stride_h = 2
        stride_w = 3

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        depth_dim = mtf.Dimension("depth", depth)
        height_dim = mtf.Dimension("height", height)
        width_dim = mtf.Dimension("width", width)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape([
                                              batch_dim, depth_dim, height_dim,
                                              width_dim, channels_dim
                                          ]))

        if pooling_method == "MAX_2D":
            mtf_outputs = mtf.layers.max_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.MaxPooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth, height // stride_h, width // stride_w, channels
            ])

        elif pooling_method == "AVG_2D":
            mtf_outputs = mtf.layers.avg_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.AveragePooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth, height // stride_h, width // stride_w, channels
            ])

        elif pooling_method == "MAX_3D":
            mtf_outputs = mtf.layers.max_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.MaxPooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        elif pooling_method == "AVG_3D":
            mtf_outputs = mtf.layers.avg_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.AveragePooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
        actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])
        self.assertAllClose(actual, expected)

        actual = self.evaluate(actual_gradient)
        if pooling_method == "MAX_2D":
            expected_non_zeros = batch * depth * height * width * channels // (
                stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_2D":
            expected = np.ones((batch, depth, height, width, channels),
                               dtype=np.float32) / stride_h / stride_w
            self.assertAllClose(actual, expected)

        elif pooling_method == "MAX_3D":
            expected_non_zeros = batch * depth * height * width * channels // (
                stride_d * stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_3D":
            expected = np.ones(
                (batch, depth, height, width, channels),
                dtype=np.float32) / stride_d / stride_h / stride_w
            self.assertAllClose(actual, expected)
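
# Sanity-check sketch (illustrative, not part of the original test): why the
# AVG_2D gradient is the constant 1 / (stride_h * stride_w). With
# non-overlapping windows, each input element feeds exactly one output
# average, so the gradient of sum(outputs) w.r.t. every input is
# 1 / (window size).
import numpy as np

h, w = 4, 6
stride_h, stride_w = 2, 3
x = np.random.randn(h, w).astype(np.float32)
# Forward: average over non-overlapping stride_h x stride_w windows.
out = x.reshape(h // stride_h, stride_h,
                w // stride_w, stride_w).mean(axis=(1, 3))
# Backward for sum(out): each input contributes to exactly one window.
grad = np.full((h, w), 1.0 / (stride_h * stride_w), dtype=np.float32)
assert np.isclose(grad.sum(), out.size)  # one unit of gradient per output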