Example #1
 def computation_fn():
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, 'my_mesh')
     mesh_shape = mtf.convert_to_shape('all:2')
     layout = 'none:all'
     mesh_devices = [''] * mesh_shape.size
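     # Note: `device_assignment` below is assumed to come from the enclosing
     # TPU test setup; it is not defined inside this function.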
     mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
         mesh_shape, mtf.convert_to_layout_rules(layout),
         mesh_devices, device_assignment)
     hidden_dim = mtf.Dimension('hidden', 3)
     w = mtf.get_variable(mesh,
                          'w',
                          shape=[hidden_dim],
                          initializer=tf.constant_initializer(
                              [0.1, -0.2, -0.1]))
     x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                      dtype=tf.float32)
     loss = mtf.reduce_mean(mtf.square(x - w))
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf_optimize.AdamWeightDecayOptimizer(
         learning_rate=0.2)
     update_ops = optimizer.apply_grads(var_grads,
                                        graph.trainable_variables)
     self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_update_ops = [
         self.lowering.lowered_operation(op) for op in update_ops
     ]
     return tf.group(tf_update_ops)
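For reference, here is a minimal CPU-only sketch of the same pattern, assuming a TF1-style session and mesh_tensorflow's PlacementMeshImpl in place of the TPU SimdMeshImpl (this sketch is not part of the original example):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

def run_on_cpu():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    hidden_dim = mtf.Dimension('hidden', 3)
    w = mtf.get_variable(mesh, 'w', shape=[hidden_dim],
                         initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
    x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim], dtype=tf.float32)
    loss = mtf.reduce_mean(mtf.square(x - w))
    # A single-device placement mesh needs no TPU device_assignment.
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mtf.convert_to_shape('all:1'),
        mtf.convert_to_layout_rules('none:all'), [''])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = lowering.export_to_tf_tensor(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Copy master variables into their mesh slices before reading.
        sess.run(lowering.copy_masters_to_slices())
        print(sess.run(tf_loss))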
Example #2
  def __init__(self, sess, use_tpu, mesh_shape, layout_rules):
    super(MeshContext, self).__init__()
    self._use_tpu = use_tpu
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._layout_rules = layout_rules

    self._d_assignment = None
    self._num_hosts = None
    self._num_cores = None

    self._cpu_devices, self._gpu_devices = self._list_cpu_gpu_devices(sess)

    if self._use_tpu:
      topology = sess.run(tpu.initialize_system())
      topo_object = tpu.Topology(serialized=topology)
      self._num_cores = int(np.prod(topo_object.mesh_shape))
      self._num_hosts = int(topo_object.num_tasks)
      num_cores_per_host = int(self._num_cores // self._num_hosts)
      assert num_cores_per_host == int(topo_object.num_tpus_per_task)

      # Get a device_assignment object for mtf.
      self._d_assignment = device_assignment.device_assignment(
          topology,
          computation_shape=[1,] * mtf.utils.topology_rank(topology),
          num_replicas=self._num_cores)

      self._mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          self._mesh_shape, self._layout_rules, None, self._d_assignment)
    else:
      self._mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          self._mesh_shape, self._layout_rules, self._gpu_devices)
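A hypothetical use of the context above, assuming the rest of the MeshContext class (including _list_cpu_gpu_devices) is defined as in the source:

with tf.Session() as sess:
    mesh_context = MeshContext(
        sess, use_tpu=False, mesh_shape='all:1',
        layout_rules=mtf.convert_to_layout_rules('none:all'))
    # On the non-TPU path this is a PlacementMeshImpl over the GPU devices.
    mesh_impl = mesh_context._mesh_impl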
Example #3
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))

                lr, update_ops = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                tf_update_ops.append(
                    tf.assign_add(tf.train.get_or_create_global_step(), 1))
                train_op = tf.group(tf_update_ops)

                return lr, train_op
Example #4
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'num_heads:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                batch_dim = mtf.Dimension('batch', batch_size)
                seq_dim = mtf.Dimension('seq', seq_length)

                input_ids = tf.random.uniform((batch_size, seq_length),
                                              minval=0,
                                              maxval=vocab_size,
                                              dtype=tf.int32)
                mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                                     [batch_dim, seq_dim])

                model = bert_lib.BertModel(config=bert_config,
                                           is_training=True,
                                           input_ids=mtf_input_ids,
                                           input_mask=None,
                                           token_type_ids=None)
                pooled = model.get_pooled_output()
                lowering = mtf.Lowering(graph, {mesh: mesh_impl})
                return lowering.export_to_tf_tensor(pooled)
Example #5
  def testMinimizePeakMemoryList_ZeroUseTensor(self):
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('a:4'), dtype=tf.int32,
                 name='X')
    y = mtf.Constant(mesh, 0, shape=mtf.convert_to_shape('b:3'), dtype=tf.int32,
                     name='Y').outputs[0]
    mtf.BroadcastOperation(y, mtf.convert_to_shape('b:3,c:2'), name='Z')

    graph = graph_interface.GraphInterface(mtf_graph)
    schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))
    # The list scheduler prefers the operation that frees the most memory.
    # When nothing is scheduled:
    #   X frees 0 entries
    #   Y frees -3 entries
    # so X runs before Y, and Z (which consumes Y) runs last.
    # Hence the schedule should be [X, Y, Z].
    self.assertEqual(schedule, [0, 1, 2])
Example #6
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s' % device_list)

  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              mesh_devices,
                                              ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    fsum = benchmark_model(mesh)
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_err = tf.to_float(lowering.export_to_tf_tensor(fsum))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err)
Example #7
    def testLayout(self):
        # Construct a Mesh TensorFlow graph and mesh.
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, "my_mesh")
        x = mtf.zeros(mesh, "a:10,b:5")
        y = mtf.zeros(mesh, "b:5,c:20")
        z = mtf.einsum([x, y], "a:10,c:20")

        # Decide on a mesh shape.
        mesh_shape = mtf.convert_to_shape("m1:4,m2:2")

        # Compute a layout based on the graph and mesh.
        # Note that knowing the identity of the outputs is important to the
        # optimization since they cannot be freed.
        layout = mtf.auto_mtf.layout(mtf_graph, mesh_shape, [z])

        a_dim = mtf.convert_to_dimension(("a", 10))
        b_dim = mtf.convert_to_dimension(("b", 5))
        c_dim = mtf.convert_to_dimension(("c", 20))

        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
        self.assertIsNone(
            layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
        self.assertEqual(
            layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
Example #8
def main(_):

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)

    cluster_spec = cluster.cluster_spec()
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

    # Only the master job takes care of the graph building,
    # everyone else can just chill for now
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    mesh_devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_node)
    ]
    print("List of devices", mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # Build the model
    fft_err = benchmark_model(mesh)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(fft_err)

    with tf.Session(server.target) as sess:
        start = time.time()
        err = sess.run(result)
        end = time.time()

        time.sleep(1)
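        # Time a second run after the warm-up above.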
        start = time.time()
        err = sess.run(result)
        end = time.time()

    print("Max absolute FFT error %f, with wall time %f" % (err,
                                                            (end - start)))
    time.sleep(1)
    exit(0)
Example #9
def get_placement_mesh(hparams):
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)

  mesh_devices = [""] * mesh_shape.size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, hparams.layout, mesh_devices)
  return mesh, mesh_impl
Example #10
def _tensor_dim_to_mesh_dim_size(hparams, tensor_dim):
  """Inspect hparams to figure out how many ways tensor_dim gets split."""
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  mesh_axis = layout_rules.tensor_dimension_to_mesh_axis(tensor_dim, mesh_shape)
  if mesh_axis is None:
    return 1
  else:
    return mesh_shape.dims[mesh_axis].size
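A quick illustration of the helper above; the hparams stub here is invented for the example:

class _FakeHParams(object):
    layout = 'batch:m1'
    mesh_shape = 'm1:4,m2:2'

batch_dim = mtf.convert_to_dimension(('batch', 1024))
hidden_dim = mtf.convert_to_dimension(('hidden', 512))
assert _tensor_dim_to_mesh_dim_size(_FakeHParams(), batch_dim) == 4  # split over m1
assert _tensor_dim_to_mesh_dim_size(_FakeHParams(), hidden_dim) == 1  # unsplit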
Example #11
 def testOptimizeLayoutTiebreak(self):
     x1 = mtf.zeros(self.mesh, "a:10,b:5")
     x2 = mtf.zeros(self.mesh, "b:5,c:20")
     mtf.einsum([x1, x2], "a:10,c:20")
     # Rewrite mesh_shape to have a dummy dimension.
     self.mesh_shape = mtf.convert_to_shape("m1:4,m2:2,m3:1")
     optimizer = self.get_layout_optimizer()
     layout = optimizer.solve()
     self.assertEqual(layout, "a:m2;b:m3;c:m1")
Example #12
def get_placement_mesh(hparams):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)

    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, hparams.layout, mesh_devices)
    return mesh, mesh_impl
Example #13
    def testMinimizePeakMemoryList(self):
        mtf_graph = mtf.Graph()
        mesh = mtf.Mesh(mtf_graph, 'my_mesh')
        x = mtf.Constant(mesh,
                         0,
                         shape=mtf.convert_to_shape('a:3,b:4'),
                         dtype=tf.int32,
                         name='X').outputs[0]
        y = mtf.Constant(mesh,
                         0,
                         shape=mtf.convert_to_shape('b:4,c:5'),
                         dtype=tf.int32,
                         name='Y').outputs[0]
        mtf.EinsumOperation([x, y],
                            mtf.convert_to_shape('a:3,b:4,c:5'),
                            name='Z')
        w = mtf.EinsumOperation([x, y],
                                mtf.convert_to_shape('a:3,c:5'),
                                name='W').outputs[0]
        mtf.BroadcastOperation(w,
                               mtf.convert_to_shape('a:3,b:4,c:5'),
                               name='V')

        graph = graph_interface.GraphInterface(mtf_graph)
        graph.set_tensor_final('Z:0')
        graph.set_tensor_final('V:0')
        schedule = list(scheduler.minimize_peak_memory(graph, 'LIST'))

        # List Scheduler prefers to schedule things that free the most memory.
        # When nothing is scheduled:
        #   X frees -12 entries.
        #   Y frees -20 entries.
        # After [X] scheduled:
        #   Y frees -20 entries.
        # After [X, Y] scheduled:
        #   Z frees -60 entries.
        #   W frees -15 entries.
        # After [X, Y, W] scheduled:
        #   Z frees -28 entries.
        #   V frees -45 entries.
        # Hence the schedule should be [X, Y, W, Z, V].
        self.assertEqual(schedule, [0, 1, 3, 2, 4])
Example #14
  def testReturnsTopoSort(self, scheduler_alg):
    mtf_graph = mtf.Graph()
    mesh = mtf.Mesh(mtf_graph, 'my_mesh')
    x = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('a:3,b:4'),
                     dtype=tf.int32,
                     name='X').outputs[0]
    y = mtf.Constant(mesh, 0,
                     shape=mtf.convert_to_shape('b:4,c:5'),
                     dtype=tf.int32,
                     name='Y').outputs[0]
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z1')
    mtf.EinsumOperation([x, y], mtf.convert_to_shape('a:3,c:5'), name='Z2')

    graph = graph_interface.GraphInterface(mtf_graph)
    graph.set_tensor_final('Z1:0')
    graph.set_tensor_final('Z2:0')
    schedule = list(scheduler.minimize_peak_memory(graph, scheduler_alg))

    self.assertCountEqual(schedule[0:2], [0, 1])
    self.assertCountEqual(schedule[2:4], [2, 3])
Example #15
def run_cifar():
    """Run MNIST training and eval loop."""
    cifar_classifier = tf.estimator.Estimator(model_fn=model_fn,
                                              model_dir=FLAGS.model_dir)
    dataset = cifar_dset()

    # Set up training and evaluation input functions.
    def train_input_fn():
        """Prepare data for training."""

        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. CIFAR is a small
        # enough dataset that we can easily shuffle the full epoch.
        ds = dataset.train(FLAGS.data_dir)
        ds_batched = ds.cache().shuffle(buffer_size=50000).batch(
            FLAGS.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of times
        # during each training session.
        ds = ds_batched.repeat(FLAGS.epochs_between_evals)
        return ds

    def eval_input_fn():
        return dataset.test(FLAGS.data_dir).batch(
            FLAGS.batch_size).make_one_shot_iterator().get_next()

    # Train and evaluate model.
    import time
    time_tot_start = time.time()
    f = open("./Het_CNN.txt", "a+")
    f.write("#Filters\t#Epochs\t#Time\t#Accuracy\t#Loss\t#Shape\n")
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_size = mesh_shape.size
    conv_shape = []

    for ep in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
        time_epoch_start = time.time()
        cifar_classifier.train(input_fn=train_input_fn, hooks=None)
        time_epoch_end = time.time() - time_epoch_start
        eval_results = cifar_classifier.evaluate(input_fn=eval_input_fn)
        print("\nEvaluation results:\n\t%s\n" % eval_results)
        print(ep, "----------->", time_epoch_end)
        f.write("%d\t%0.6f\t%0.6f\t%0.6f\t%s\n" %
                (ep, time_epoch_end, eval_results['accuracy'],
                 eval_results['loss'], conv_shape))

    time_tot_end = time.time() - time_tot_start
    print("Total Time ", FLAGS.train_epochs, " Epochs", time_tot_end)

    f.close()
Example #16
def main(_):

    #layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    mesh_impl = HvdSimdMeshImpl(mtf.convert_to_shape(mesh_shape),
                                mtf.convert_to_layout_rules(layout_rules))

    # Build the model
    # Create computational graphs and some initializations
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "nbody_mesh")

    initial_conditions, mesh_final_field = lpt_prototype(
        mesh, bs=FLAGS.box_size, nc=FLAGS.nc, batch_size=FLAGS.batch_size)

    # Lower mesh computation
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    initc = lowering.export_to_tf_tensor(initial_conditions)
    result = lowering.export_to_tf_tensor(mesh_final_field)

    with tf.Session() as sess:
        start = time.time()
        a, c = sess.run([initc, result])
        end = time.time()
        ttime = (end - start)
        print('Time for ', mesh_shape, ' is : ', ttime)

    if comm.rank == 0:
        plt.figure(figsize=(9, 3))
        plt.subplot(121)
        plt.imshow(a[0].sum(axis=2))
        plt.title('Initial Conditions')

        plt.subplot(122)
        plt.imshow(c[0].sum(axis=2))
        plt.title('Mesh TensorFlow')
        plt.colorbar()
        plt.savefig("mesh_nbody_%d-row:%d-col:%d.png" %
                    (FLAGS.nc, FLAGS.nx, FLAGS.ny))
        plt.close()

    exit(0)
Example #17
def layout(mtf_graph, mesh_shape, mtf_outputs=()):
  """Compute layout rules based on a computational graph and mesh shape.

  Args:
    mtf_graph: a mtf.Graph.
    mesh_shape: an mtf.Shape, str, or listlike of mtf.Dimension.
    mtf_outputs: an optional iterable of mtf.Tensor, representing the outputs
        of the computation.

  Returns:
    a mtf.LayoutRules
  """
  mesh_shape = mtf.convert_to_shape(mesh_shape)
  estimator = memory_estimator.MemoryEstimator(mtf_graph, mesh_shape,
                                               mtf_outputs)
  optimizer = layout_optimizer.LayoutOptimizer(estimator)
  return mtf.convert_to_layout_rules(optimizer.solve())
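A short sketch of calling this helper on a tiny graph, mirroring the testLayout example earlier in this section:

graph = mtf.Graph()
mesh = mtf.Mesh(graph, 'my_mesh')
x = mtf.zeros(mesh, 'a:10,b:5')
y = mtf.zeros(mesh, 'b:5,c:20')
z = mtf.einsum([x, y], 'a:10,c:20')
rules = layout(graph, 'm1:4,m2:2', [z])  # returns an mtf.LayoutRules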
Example #18
def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

  # Resolve the TPU environment
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu,
      zone=FLAGS.tpu_zone,
      project=FLAGS.gcp_project
  )

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=100,
      save_summary_steps=100,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=100,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

  model = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      predict_batch_size=1,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)

  def dummy_input_fn(params):
    dset = tf.data.Dataset.from_tensor_slices(tf.zeros(shape=[params['batch_size'],1],
                                                       dtype=tf.float32))
    return dset

  # Run the predict loop forever; we will connect to this process with a profiler
  for i, f in enumerate(model.predict(input_fn=dummy_input_fn)):
    print(i)
    np.save(file_io.FileIO(FLAGS.output_dir+'/field_%d.npy'%i, 'w'), f['field'])
Example #19
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    del features

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)

    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        field = nbody_model(mesh)
        batch_dim, x_dim, y_dim, z_dim = field.shape
        x_dim_nosplit = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
        y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

        # Until we implement distributed outputs, we only return one example
        field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size - 1])
        field_slice = mtf.reshape(
            field_slice,
            [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
        #field_slice = field

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

    with mtf.utils.outside_all_rewrites():
        return tpu_estimator.TPUEstimatorSpec(mode,
                                              predictions={'field': tf_field})
Example #20
def run_toy_model_tpu():
    """Run a toy model on TPU."""
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    iterations_per_loop = FLAGS.iterations
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=iterations_per_loop,
        save_summary_steps=iterations_per_loop,
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=iterations_per_loop,
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.
            BROADCAST))
    classifier = tpu_estimator.TPUEstimator(use_tpu=True,
                                            model_fn=model_fn,
                                            config=config,
                                            train_batch_size=FLAGS.batch_size,
                                            eval_batch_size=FLAGS.batch_size)
    current_step = estimator_lib._load_global_step_from_checkpoint_dir(
        FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
    logging.info('Current step %d', current_step)
    if FLAGS.steps_per_checkpoint == 0:
        classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
        return
    while current_step < FLAGS.train_steps:
        next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
                              FLAGS.train_steps)
        classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
        current_step = next_checkpoint
        logging.info('Starting to evaluate.')
        eval_results = classifier.evaluate(
            input_fn=ToyModelInput(), steps=156
        )  # since we have 10000 examples and batch_size = 64 per host
        logging.info('Eval results: %s', eval_results)
Example #21
def main(_):

  # Creating layout and mesh implementation
  mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
  layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                  ("ny", "col"), ("ty", "row"), ("tz", "col"), ("ty_lr", "row"),
                  ("tz_lr", "col"), ("nx_block", "row"), ("ny_block", "col")]
  mesh_impl = HvdSimdMeshImpl(
      mtf.convert_to_shape(mesh_shape),
      mtf.convert_to_layout_rules(layout_rules))

  # Create the graph and mesh
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")

  ## Load initial power spectrum
  klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
  plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]

  # Defines the computational graph for the nbody
  initial_conditions, final_field = nbody_fn(mesh, klin, plin)

  # Lower mesh computation
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  # Retrieve fields as tf tensors
  tf_initc = lowering.export_to_tf_tensor(initial_conditions)
  tf_final = lowering.export_to_tf_tensor(final_field)

  with tf.Session() as sess:
    start = time.time()
    init_conds, final = sess.run([tf_initc, tf_final])
    end = time.time()
    print('\n Time for the mesh run : %f \n' % (end - start))

  # Export these fields
  np.save('simulation_output_%d.npy' % comm.Get_rank(), final)
  np.save('simulation_input_%d.npy' % comm.Get_rank(), init_conds)

  exit(0)
Example #22
def main(_):

  tf.logging.set_verbosity(tf.logging.INFO)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)

  # Resolve the TPU environment
  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  run_config = tf.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=None,  # Disable the default saver
      save_checkpoints_secs=None,  # Disable the default saver
      log_step_count_steps=100,
      save_summary_steps=100,
      tpu_config=tpu_config.TPUConfig(
          num_shards=mesh_shape.size,
          iterations_per_loop=100,
          num_cores_per_replica=1,
          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

  model = tpu_estimator.TPUEstimator(
      use_tpu=True,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size)

  def dummy_input_fn(params):
    """Dummy input function """
    return tf.zeros(
        shape=[params['batch_size']], dtype=tf.float32), tf.zeros(
            shape=[params['batch_size']], dtype=tf.float32)

  # Run a long evaluate loop; we will connect to this process with a profiler
  model.evaluate(input_fn=dummy_input_fn, steps=10000)
Example #23
def train_and_eval():
    """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """
    tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))

    # Open a session to get the list of CPU devices to hold master variables.
    with tf.Session(target=FLAGS.master,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        topology = sess.run(tpu.initialize_system())
        cpu_devices = _list_cpu_devices(sess)

    topo_object = tf.contrib.tpu.Topology(serialized=topology)
    num_cores = int(np.prod(topo_object.mesh_shape))
    num_hosts = int(topo_object.num_tasks)
    num_cores_per_host = int(num_cores // num_hosts)
    assert num_cores_per_host == int(topo_object.num_tpus_per_task)

    # Get a device_assignment object for mtf.
    d_assignment = device_assignment.device_assignment(
        topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

    # Get mesh_impl.
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = unet.get_layout()
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules, None,
                                                d_assignment)

    for _ in range(FLAGS.num_training_loops):
        _train_phase(mesh_impl, cpu_devices, d_assignment, num_hosts,
                     num_cores)
        _eval_phase(mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)

    _shutdown()

    tf.logging.info('finished.')
Example #24
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    # Create a mesh named "my_mesh" on the wrapped graph.
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    print("mesh_shape.size = ", mesh_shape.size)
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
Example #25
  def __init__(
      self,
      model_dir,
      tpu,
      tpu_job_name=None,
      tpu_zone=None,
      gcp_project=None,
      tpu_topology="v2-8",
      model_parallelism=8,
      batch_size=("sequences_per_batch", 1),
      sequence_length=None,
      model_type="bitransformer",
      layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
      mesh_shape=None,
      mesh_devices=None,
      autostack=True,
      learning_rate_schedule=None,
      keep_checkpoint_max=None,
      save_checkpoints_steps=5000,
      optimizer=None,
      predict_fn=None,
      variable_filter=None,
      ensemble_inputs=None,
      iterations_per_loop=100):
    """Constructor for MtfModel class.

    Args:
      model_dir: str, directory to save the model.
      tpu: str, the TPU address to use.
      tpu_job_name: str, name of the TPU worker binary.
      tpu_zone: str, GCE zone where the Cloud TPU is located
      gcp_project: str, project name for the Cloud TPU-enabled project.
      tpu_topology: str, e.g. "2x2" or "v2-8".
      model_parallelism: integer, the number of cores per model replica.
      batch_size: An integer or a (method, value) pair to pass to
        compute_batch_size(). Note that this is the global batch size and not
        the per-shard batch size.
      sequence_length: an integer or a dict from feature-key to integer
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
      model_type: str, a model type from mesh tf models.
      layout_rules: an input to mtf.convert_to_layout_rules()
      mesh_shape: an mtf.Shape or string (e.g., "model:2,batch:4") specifying
        how the data/model should be parallelized. If None (default), the mesh
        shape will be constructed using the supplied `tpu_topology` and
        `model_parallelism` arguments.
      mesh_devices: a list of strings, the device names to use for each mesh
        slice. Only required for GPU.
      autostack: boolean, internally combine variables.
      learning_rate_schedule: an optional function taking the scalar name
        argument `step` and the numeric argument `total_train_steps`, returning
        the scalar learning rate.
      keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
      save_checkpoints_steps: an integer, steps per checkpoint.
      optimizer: a class extending optimize.Optimizer, required for training.
      predict_fn: an optional function that can be used to override the default
        transformer prediction behavior. Must return a tensor of shape
        [batch_dim, length_dim] that will be the prediction for each example.
        Must accept the following arguments:
          - model: a Unitransformer or Bitransformer
          - features: a dict representing an example. Every value will be an
            mtf.Tensor with shape [batch_dim, length_dim].
          - variable_dtype: an mtf.VariableDType
      variable_filter: a str, a variable will only be trained if its name
        matches this regex. If None (default), train all trainable variables.
      ensemble_inputs: an integer, see `train_model` docstring for details.
      iterations_per_loop: integer, steps per train loop
    """
    mesh_shape = mesh_shape or (
        utils.tpu_mesh_shape(tpu_topology, model_parallelism) if tpu else "")

    sequence_length = sequence_length or {"inputs": 512, "targets": 512}

    if isinstance(sequence_length, int):
      sequence_length = {"inputs": sequence_length,
                         "targets": sequence_length}
    self._learning_rate_schedule = (
        learning_rate_schedule or
        learning_rate_schedules.learning_rate_schedule_noam)

    self._optimizer = optimizer or optimize.AdafactorOptimizer

    self._sequence_length = sequence_length
    self._model_dir = model_dir
    self._model_type = model_type
    self._ensemble_inputs = ensemble_inputs

    self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
    self._mesh_shape = mtf.convert_to_shape(mesh_shape)
    self._mesh_devices = mesh_devices

    self._autostack = autostack
    self._keep_checkpoint_max = keep_checkpoint_max
    self._save_checkpoints_steps = save_checkpoints_steps
    self._predict_fn = predict_fn
    self._variable_filter = variable_filter
    self._ensemble_inputs = ensemble_inputs
    self._iterations_per_loop = iterations_per_loop

    self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu, zone=tpu_zone, project=gcp_project) if tpu else None
    self._tpu = tpu
    self._tpu_job_name = tpu_job_name
    self._estimator = None

    # Must be called after _sequence_length, _mesh_shape, and _layout_rules are
    # set.
    self.batch_size = batch_size
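A hypothetical instantiation of this constructor; every argument value below is illustrative only:

mtf_model = MtfModel(
    model_dir='gs://my-bucket/my-model',  # hypothetical path
    tpu='my-tpu',                         # hypothetical TPU name
    tpu_topology='v2-8',
    model_parallelism=8,
    batch_size=('sequences_per_batch', 64),
    sequence_length={'inputs': 512, 'targets': 128})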
Example #26
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
Example #27
    def test_get_laidout_tensors(self, is_eval_mode):
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)

            # Get a device_assignment object for mtf.
            d_assignment = device_assignment.device_assignment(
                topology,
                computation_shape=[
                    1,
                ] * mtf.utils.topology_rank(topology),
                num_replicas=num_cores)

            # Hacked dataset creator: creates different datasets for the first and
            # second call, in order to test SimdMeshImplInputReader.
            self.sub_batch_created_times = 0

            def stateful_ds_creator():
                whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                sub_batch = tf.slice(whole_batch,
                                     [self.sub_batch_created_times * 2, 0],
                                     [2, 4])
                self.sub_batch_created_times += 1
                return tf.data.Dataset.from_tensors(
                    sub_batch).repeat().unbatch()

            batch_dim = mtf.Dimension("batch", batch_io_dim)
            io_dim = mtf.Dimension("io", batch_io_dim)
            mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

            # Get mesh_impl.
            mesh_shape = mtf.convert_to_shape(mesh_shape)
            layout_rules = mtf.convert_to_layout_rules(layout)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, None, d_assignment)

            simd_input_reader = input_reader.SimdMeshImplInputReader(
                mesh_impl,
                stateful_ds_creator,
                mtf_input_shapes,
                external_worker=False,
                is_eval_mode=is_eval_mode)

            def model_fn(features):
                return features

            replicated_computation = tpu.replicate(
                computation=model_fn,
                inputs=[[]] * num_cores,
                infeed_queue=simd_input_reader.infeed_queue,
                device_assignment=d_assignment)

            simd_input_reader.start_infeed_thread(sess, 1)
            results = sess.run(replicated_computation)
            print("results: {}".format(results))

            core_0_data = results[0][0]
            core_1_data = results[1][0]
            print("core_0_data: {}".format(core_0_data))
            print("core_1_data: {}".format(core_1_data))

            if is_eval_mode:
                # If there is only one dataset object, then the stateful_ds_creator()
                # should be called only once.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_1_data)
            else:
                # If there are two dataset objects, then the stateful_ds_creator()
                # should be called twice.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
                    core_1_data)

            sess.run(tf.tpu.shutdown_system())
Example #28
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Get the number of microbatches per batch for serialized training.
        # If the param `tokens_per_mb_per_replica` is None, this defaults to 1
        # and no microbatching is performed.
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    # Add num microbatches to params.
    params["num_microbatches"] = num_microbatches

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)
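            # A note on the conversion above (the constant is dataset/tokenizer
            # specific, so treat this as an assumption): loss is in nats per
            # token, 0.29335 is taken to be the tokens-per-byte ratio of the
            # eval corpus, and dividing by ln(2) converts nats to bits, e.g.
            # a loss of 1.0 nat/token gives 1.0 * 0.29335 / 0.6931 ≈ 0.423
            # bits per byte.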

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # tf_loss_batch may have z_loss and other auxiliary terms added
                # to it, so this may need to be calculated separately in future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
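
# A minimal standalone sketch (not part of the example above) of the masking
# trick that _lambada_metric_fn relies on: positions whose label equals the
# EOS id are padding, and accuracy is averaged over the remaining (answer)
# positions only. Assumes a TF1.x environment, like the examples on this page.
import tensorflow as tf

eos_id = 0
labels = tf.constant([[5, 7, 0, 0],
                      [3, 0, 0, 0]])         # 0 marks eos/padding
predictions = tf.constant([[5, 2, 0, 0],
                           [3, 0, 0, 0]])
# Indices of every non-padding position, shape [num_answers, 2]
answer_positions = tf.where(tf.math.not_equal(labels, eos_id))
# Per-position correctness, gathered at the answer positions only
correct = tf.gather_nd(tf.math.equal(predictions, labels), answer_positions)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

with tf.Session() as sess:
    print(sess.run(accuracy))  # 0.6666667: 2 of the 3 answer tokens match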
Example #29
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None):
        hparams = copy.deepcopy(hparams)
        use_tpu = params and params.get("use_tpu", False)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
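            # Rough sizing (illustrative numbers, not from the source): with 8
            # replicas the placer reserves 8 * 300MB = 2.4GB on worker 0, and
            # variables are then placed so as to balance estimated memory use
            # across hosts.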
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
Example #30
def maybe_reshape_attention_input_for_2d_sharding(
    context, q, k, v, bias, unsplittable_dims):
  """Reshape the inputs to attention to split over an unused mesh dimension.

  In the case where the attention computation is unnecessarily replicated,
  this function reshapes the attention inputs to remove the unnecessary
  replication.

  This becomes relevant when doing 2-dimensional model parallelism.
  d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff] are
  sharded over the other mesh dimension.  This fully distributes all of the
  einsum operations, except for the internals of the attention computation.

  To distribute that computation, this function creates a new tensor-dimension
  from the low bits of either the batch dimension or the num_heads dimension,
  and then splits that dimension over the unused mesh dimension.

  Args:
    context: a transformer.Context
    q: a Tensor
    k: a Tensor
    v: a Tensor
    bias: a Tensor
    unsplittable_dims: a list of tensor-dimensions not to split.  The key/value
      dimensions should be passed here.
  Returns:
    reshaped_q: a Tensor
    reshaped_k: a Tensor
    reshaped_v: a Tensor
    reshaped_bias: a Tensor
  """
  original_inputs = q, k, v, bias
  # we need to know the layout and mesh-shape to figure out what to do.
  if not context or not context.model.layout or not context.model.mesh_shape:
    return original_inputs
  mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(context.model.layout)
  # find a mesh dim that is unused (no tensor-dimension is split across it)
  mesh_axis_used = [False] * mesh_shape.ndims
  for x in original_inputs:
    for mesh_axis in layout_rules.tensor_layout(
        x.shape, mesh_shape).tensor_axis_to_mesh_axis:
      if mesh_axis is not None:
        mesh_axis_used[mesh_axis] = True
  if False not in mesh_axis_used:
    return original_inputs
  mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]
  # Choose an appropriate name for the new tensor-dimension so that the layout
  #   will know to split it across the unused mesh dimension.
  tensor_dim_name = layout_rules.mesh_dimension_name_to_tensor_dimension_names(
      mesh_dim.name)
  if tensor_dim_name:
    tensor_dim_name = tensor_dim_name[0]
  else:
    return original_inputs
  # Find a tensor-dimension that we can further split, by breaking off the
  # lower bits into our new tensor-dimension.
  # This resplittable tensor-dimension must be present in all of q, k, v
  #   and must be large enough to be further split.
  resplittable_dim = None
  for d in q.shape.dims:
    if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
      num_splits = mtf.tensor_dim_to_mesh_dim_size(
          context.model.layout, context.model.mesh_shape, d)
      if d.size % (num_splits * mesh_dim.size) == 0:
        resplittable_dim = d
        break
  if not resplittable_dim:
    return original_inputs
  new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
  new_dim_low = mtf.Dimension(tensor_dim_name,
                              resplittable_dim.size // num_splits)
  def _my_reshape(x):
    if x and resplittable_dim in x.shape.dims:
      return mtf.replace_dimensions(
          x, resplittable_dim, [new_dim_high, new_dim_low])
    else:
      return x
  return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
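
# A worked example of the reshape above, under assumed shapes (not from the
# source): suppose mesh_shape = "x:2,y:4", the layout maps "heads" -> "y" and
# "batch_x" -> "x", and q/k/v share an unsplit batch dim of size 8. Mesh dim
# "x" is then unused inside attention, num_splits == 1 for "batch", and the
# function rewrites
#   batch: 8  ->  [batch: 1, batch_x: 8]
# after which "batch_x" is split 2-ways across mesh dim "x", removing the
# redundant replication of the attention einsums.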
Example #31
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(ctx.device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
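        # Illustrative shapes (assumed; the physical shape depends on the TPU
        # slice): mesh_shape.to_integer_list for "rows:4,cols:8" is [4, 8],
        # while physical_shape comes from the device topology, e.g. something
        # like [4, 4, 2]. auto_logical_to_physical_tpu then chooses a core
        # ordering intended to keep communication within each logical mesh
        # dimension local on the chip network.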
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   layout=layout_rules,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_logits,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_logits,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_logits = tf.reshape(masked_lm_logits,
                                              [-1, masked_lm_logits.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_logits = tf.reshape(
                    next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(masked_lm_example_loss),
                lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
                masked_lm_weights,
                lowering.export_to_tf_tensor(next_sentence_example_loss),
                lowering.export_to_tf_tensor(next_sentence_logits),
                next_sentence_labels
            ])

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
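
# A standalone toy run (not from the example above) of the flattening used in
# metric_fn earlier in this example: logits are reshaped to [positions, vocab]
# and argmax'd, and accuracy is weighted so that padded prediction slots do
# not count. Assumes TF1.x, like the examples on this page.
import tensorflow as tf

logits = tf.constant([[[2.0, 1.0],
                       [0.1, 0.9]]])    # [batch=1, num_preds=2, vocab=2]
ids = tf.constant([[0, 1]])             # true token ids
weights = tf.constant([[1.0, 0.0]])     # second prediction slot is padding
flat_logits = tf.reshape(logits, [-1, logits.shape[-1]])
preds = tf.argmax(flat_logits, axis=-1, output_type=tf.int32)
acc, acc_update = tf.metrics.accuracy(labels=tf.reshape(ids, [-1]),
                                      predictions=preds,
                                      weights=tf.reshape(weights, [-1]))

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(acc_update)
    print(sess.run(acc))  # 1.0: only the weighted position is scored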
Example #32
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])