Пример #1
0
            def computation_fn():
                """Build a tiny MTF training graph and return `(lr, train_op)`.

                A three-element variable `w` is fit to a constant `x` by
                minimising the mean squared difference on a SIMD mesh of shape
                'all:2'. The `mtf.Lowering` is stored on `self.lowering` so the
                enclosing scope can reach it afterwards.

                Returns:
                  A tuple `(lr, train_op)`: the first value returned by
                  `optimization_lib.create_optimizer` and a grouped TF op that
                  runs every lowered update op plus a global-step increment.
                """
                mtf_graph = mtf.Graph()
                mtf_mesh = mtf.Mesh(mtf_graph, 'my_mesh')

                # Mesh implementation: shape 'all:2' with layout 'none:all'.
                shape = mtf.convert_to_shape('all:2')
                rules = mtf.convert_to_layout_rules('none:all')
                devices = [''] * shape.size
                simd_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    shape, rules, devices, device_assignment)

                # Model: a single 'hidden' dimension of size 3.
                hidden = mtf.Dimension('hidden', 3)
                w_init = tf.constant_initializer([0.1, -0.2, -0.1])
                weights = mtf.get_variable(
                    mtf_mesh, 'w', shape=[hidden], initializer=w_init)
                target = mtf.constant(
                    mtf_mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
                residual = target - weights
                loss = mtf.reduce_mean(mtf.square(residual))

                # Positional arguments after the loss are presumably learning
                # rate, train steps and warmup steps -- confirm against
                # optimization_lib.create_optimizer.
                lr, mtf_updates = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
                self.lowering = mtf.Lowering(mtf_graph, {mtf_mesh: simd_impl})

                # Translate each MTF update op to its TF counterpart, then
                # bump the global step as part of the same train op.
                lowered_updates = []
                for update in mtf_updates:
                    lowered_updates.append(
                        self.lowering.lowered_operation(update))
                step = tf.train.get_or_create_global_step()
                lowered_updates.append(tf.assign_add(step, 1))

                return lr, tf.group(lowered_updates)
Пример #2
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator (masked-LM / next-sentence model).

        Builds the Mesh TensorFlow graph for `bert_lib.BertModel`, lowers it
        onto a SIMD mesh and returns the spec for the requested mode.

        Args:
          features: dict of tf.Tensors. Keys read here: "input_ids",
            "input_mask", "segment_ids", "masked_lm_positions",
            "masked_lm_ids", "masked_lm_weights", "next_sentence_labels".
          labels: unused.
          mode: a `tf.estimator.ModeKeys` value; only TRAIN and EVAL are
            supported.
          params: dict; `params["context"]` is the TPU context providing
            `num_hosts`, `num_replicas`, `device_assignment` and
            `tpu_host_placement_function`.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` for TRAIN or EVAL.

        Raises:
          ValueError: if `mode` is neither TRAIN nor EVAL.
        """
        # No spec is built for other modes; fail fast with a clear message
        # instead of building the whole graph and implicitly returning None.
        if mode not in (tf.estimator.ModeKeys.TRAIN,
                        tf.estimator.ModeKeys.EVAL):
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s", device_list)
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries, so it starts with a non-zero
        # usage figure while every other host starts at 0.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        # Map the logical mesh onto the physical TPU topology.
        physical_shape = list(ctx.device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        # Static shapes of the inputs define the MTF dimensions.
        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   layout=layout_rules,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        # Anonymize every tensor that is exported back to TF below.
        # NOTE(review): presumably needed so the exported tensors are not
        # split across the mesh -- confirm against the mtf documentation.
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        # TRAIN mode: the optimizer must be added before the graph is lowered.
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        # The reported loss excludes `extra_loss`; only the optimizer sees it.
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: %s", tf_update_ops)
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_logits,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_logits,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model.

                Returns a dict of `tf.metrics` results keyed by
                "masked_lm_accuracy", "masked_lm_loss",
                "next_sentence_accuracy" and "next_sentence_loss".
                """
                # Flatten everything so each masked position is one example.
                masked_lm_logits = tf.reshape(masked_lm_logits,
                                              [-1, masked_lm_logits.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_logits = tf.reshape(
                    next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            # Argument order must match the metric_fn signature above.
            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(masked_lm_example_loss),
                lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
                masked_lm_weights,
                lowering.export_to_tf_tensor(next_sentence_example_loss),
                lowering.export_to_tf_tensor(next_sentence_logits),
                next_sentence_labels
            ])

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])

            # Only EVAL can reach this point: other modes were rejected at
            # the top of the function and TRAIN returned just above.
            return tf.estimator.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Пример #3
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator (classification model).

        Builds the Mesh TensorFlow graph via `create_model`, lowers it onto a
        SIMD mesh and returns the spec for TRAIN, EVAL or prediction.

        Args:
          features: dict of tf.Tensors. Keys read here: "input_ids",
            "input_mask", "segment_ids", "label_ids" and, optionally,
            "is_real_example" (defaults to all ones when absent).
          labels: unused.
          mode: a `tf.estimator.ModeKeys` value.
          params: dict; `params["context"]` is the TPU context providing
            `num_hosts`, `num_replicas`, `device_assignment` and
            `tpu_host_placement_function`.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec`. Any mode other than TRAIN or
          EVAL yields a prediction spec with a "probabilities" output.
        """

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s", device_list)
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries, so it starts with a non-zero
        # usage figure while every other host starts at 0.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        # Weight for each example in the eval metrics; every example counts
        # fully when the input pipeline does not supply the feature.
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        # Static shapes of the inputs define the MTF dimensions.
        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)
        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        # NOTE(review): the label dimension reuses the name "seq", so any
        # layout rule for "seq" also applies to it. This looks like a
        # copy-paste slip, but renaming could change the layout of existing
        # runs -- confirm before changing.
        num_labels_dim = mtf.Dimension("seq", num_labels)
        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                       mtf_input_mask, mtf_segment_ids,
                                       mtf_label_ids, num_labels_dim,
                                       layout_rules, mesh_shape)
        # Anonymize every tensor that is exported back to TF below.
        total_loss = mtf.anonymize(total_loss)
        per_example_loss = mtf.anonymize(per_example_loss)
        logits = mtf.anonymize(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                max_optimized_variable_size=FLAGS.max_optimized_variable_size,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)
        elif mode != tf.estimator.ModeKeys.EVAL:
            # Prediction path: `probabilities` is exported after lowering, so
            # it gets the same anonymization as `logits` above. The SIMD mesh
            # implementation is understood to export only fully replicated
            # tensors. Done only here so TRAIN/EVAL graphs gain no extra op.
            probabilities = mtf.anonymize(probabilities)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: %s", tf_update_ops)
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                """Returns accuracy and mean loss weighted by is_real_example."""
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            # Argument order must match the metric_fn signature above.
            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(per_example_loss), label_ids,
                lowering.export_to_tf_tensor(logits), is_real_example
            ])

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = bert_lib.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:
                # On TPU the checkpoint initialization is deferred to a
                # scaffold function handed to the estimator spec.

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                    scaffold_fn=scaffold_fn)
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    prediction_hooks=[restore_hook],
                    predictions={
                        "probabilities":
                        lowering.export_to_tf_tensor(probabilities)
                    },
                    scaffold_fn=scaffold_fn)
Пример #4
0
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator (start/end span prediction model).

    Builds the Mesh TensorFlow graph via `create_model`, lowers it onto a SIMD
    mesh and returns the spec for the requested mode.

    Args:
      features: dict of tf.Tensors. Keys read here: "unique_ids", "input_ids",
        "input_mask", "segment_ids" and, in TRAIN mode, "start_positions" and
        "end_positions".
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value; TRAIN and PREDICT are supported.
      params: dict; `params["context"]` is the TPU context providing
        `num_hosts`, `num_replicas`, `device_assignment` and
        `tpu_host_placement_function`.

    Returns:
      A `tf.estimator.tpu.TPUEstimatorSpec` for TRAIN or PREDICT.

    Raises:
      ValueError: if `mode` is neither TRAIN nor PREDICT (raised after the
        graph has been built and lowered).
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    predicting = mode == tf.estimator.ModeKeys.PREDICT

    tf.logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s", feature_name,
                      features[feature_name].shape)

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    tpu_context = params["context"]
    host_count = tpu_context.num_hosts
    place_on_host = tpu_context.tpu_host_placement_function
    host_devices = []
    for host_index in range(host_count):
      host_devices.append(place_on_host(host_id=host_index))
    tf.logging.info("device_list = %s", host_devices)

    # Worker 0 caches all the TPU binaries (300M per replica), so it starts
    # with a non-zero usage figure; every other host starts at 0.
    per_replica_cache = 300 * 1000000
    initial_usage = [per_replica_cache * tpu_context.num_replicas]
    initial_usage.extend([0] * (host_count - 1))
    var_placer = mtf.utils.BalancedVariablePlacer(host_devices, initial_usage)

    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, [""] * mesh_shape.size,
        tpu_context.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    def to_mtf(tensor, dims):
      """Imports a TF tensor into the mesh with the given dimensions."""
      return mtf.import_tf_tensor(mesh, tensor, dims)

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    # Static shape of the inputs defines the MTF dimensions.
    input_shape = input_ids.get_shape()
    batch_dim = mtf.Dimension("batch", input_shape[0].value)
    seq_dim = mtf.Dimension("seq", input_shape[1].value)

    mtf_input_ids = to_mtf(input_ids, [batch_dim, seq_dim])
    mtf_input_mask = to_mtf(input_mask, [batch_dim, seq_dim])
    mtf_segment_ids = to_mtf(segment_ids, [batch_dim, seq_dim])

    start_logits, end_logits = create_model(
        bert_config=bert_config,
        is_training=training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        segment_ids=mtf_segment_ids)

    if training:

      def compute_loss(logits, positions):
        """Mean negative log-probability of `positions` over `seq_dim`."""
        targets = mtf.one_hot(positions, output_dim=seq_dim)
        log_probs = mtf.log_softmax(logits, seq_dim)
        picked = mtf.reduce_sum(targets * log_probs, reduced_dim=seq_dim)
        return -mtf.reduce_mean(picked)

      mtf_start_positions = to_mtf(features["start_positions"], [batch_dim])
      mtf_end_positions = to_mtf(features["end_positions"], [batch_dim])

      start_loss = compute_loss(start_logits, mtf_start_positions)
      end_loss = compute_loss(end_logits, mtf_end_positions)

      # The optimizer must be added before the graph is lowered.
      total_loss = (start_loss + end_loss) / 2.0
      _, update_ops = optimization_lib.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          max_optimized_variable_size=FLAGS.max_optimized_variable_size,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
    elif predicting:
      # Anonymize the tensors that are exported back to TF below.
      start_logits = mtf.anonymize(start_logits)
      end_logits = mtf.anonymize(end_logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    if training:
      tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
      global_step = tf.train.get_global_step()
      tf_update_ops = []
      for update in update_ops:
        tf_update_ops.append(lowering.lowered_operation(update))
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: %s", tf_update_ops)
      train_op = tf.group(tf_update_ops)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      checkpoint_info = bert_lib.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      assignment_map, initialized_variable_names = checkpoint_info
      if not use_tpu:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
      else:
        # On TPU the checkpoint initialization is deferred to a scaffold
        # function handed to the estimator spec.

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      from_ckpt = var.name in initialized_variable_names
      init_string = ", *INIT_FROM_CKPT*" if from_ckpt else ""
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)

      if predicting:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions={
                "unique_ids": unique_ids,
                "start_logits": lowering.export_to_tf_tensor(start_logits),
                "end_logits": lowering.export_to_tf_tensor(end_logits),
            },
            scaffold_fn=scaffold_fn)

      if not training:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))

      # TRAIN: register a sharded saver plus the MTF checkpoint listener.
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[mtf.MtfCheckpointSaverListener(lowering)])

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook],
          scaffold_fn=scaffold_fn)