Example #1
    def test_get_laidout_tensors(self, is_eval_mode):
        mesh_shape = "mesh_x:2, mesh_y:1"
        layout = "batch:mesh_x, io:mesh_y"
        batch_io_dim = 4

        with tf.Session() as sess:
            topology, num_cores = self.initialize_system(sess)

            # Get a device_assignment object for mtf.
            d_assignment = device_assignment.device_assignment(
                topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

            # Hacked dataset creator: creates different datasets for the first and
            # second call, in order to test SimdMeshImplInputReader.
            self.sub_batch_created_times = 0

            def stateful_ds_creator():
                whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
                sub_batch = tf.slice(whole_batch,
                                     [self.sub_batch_created_times * 2, 0],
                                     [2, 4])
                self.sub_batch_created_times += 1
                return tf.data.Dataset.from_tensors(
                    sub_batch).repeat().unbatch()

            batch_dim = mtf.Dimension("batch", batch_io_dim)
            io_dim = mtf.Dimension("io", batch_io_dim)
            mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

            # Get mesh_impl.
            mesh_shape = mtf.convert_to_shape(mesh_shape)
            layout_rules = mtf.convert_to_layout_rules(layout)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, None, d_assignment)

            simd_input_reader = input_reader.SimdMeshImplInputReader(
                mesh_impl,
                stateful_ds_creator,
                mtf_input_shapes,
                external_worker=False,
                is_eval_mode=is_eval_mode)

            def model_fn(features):
                return features

            replicated_computation = tpu.replicate(
                computation=model_fn,
                inputs=[[]] * num_cores,
                infeed_queue=simd_input_reader.infeed_queue,
                device_assignment=d_assignment)

            simd_input_reader.start_infeed_thread(sess, 1)
            results = sess.run(replicated_computation)
            print("results: {}".format(results))

            core_0_data = results[0][0]
            core_1_data = results[1][0]
            print("core_0_data: {}".format(core_0_data))
            print("core_1_data: {}".format(core_1_data))

            if is_eval_mode:
                # In eval mode there is only one dataset object, so
                # stateful_ds_creator() is called only once and both cores
                # receive the same sub-batch.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_1_data)
            else:
                # In train mode there are two dataset objects, so
                # stateful_ds_creator() is called twice and the cores receive
                # different sub-batches.
                self.assertAllClose(
                    np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32),
                    core_0_data)
                self.assertAllClose(
                    np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32),
                    core_1_data)

            sess.run(tf.tpu.shutdown_system())
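
The `initialize_system` helper above is not shown in this snippet. A plausible minimal version, assuming the TF1-era `tf.tpu.initialize_system()` and `tf.tpu.experimental.Topology` APIs (the helper body and core-count arithmetic are guesses, not the original implementation):

    def initialize_system(self, sess):
        # Bring up the TPU system and fetch the serialized topology proto.
        topology_proto = sess.run(tf.tpu.initialize_system())
        topology = tf.tpu.experimental.Topology(serialized=topology_proto)
        # One replica per core: total cores = hosts * cores-per-host.
        num_cores = topology.num_tasks * topology.num_tpus_per_task
        return topology, num_cores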
Example #2
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)
        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        num_labels_dim = mtf.Dimension("seq", num_labels)
        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                       mtf_input_mask, mtf_segment_ids,
                                       mtf_label_ids, num_labels_dim,
                                       layout_rules, mesh_shape)
        total_loss = mtf.anonymize(total_loss)
        per_example_loss = mtf.anonymize(per_example_loss)
        logits = mtf.anonymize(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                max_optimized_variable_size=FLAGS.max_optimized_variable_size,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(per_example_loss), label_ids,
                lowering.export_to_tf_tensor(logits), is_real_example
            ])

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = bert_lib.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                    scaffold_fn=scaffold_fn)
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    prediction_hooks=[restore_hook],
                    predictions={
                        "probabilities":
                        lowering.export_to_tf_tensor(probabilities)
                    },
                    scaffold_fn=scaffold_fn)
Example #3
File: cifar10.py  Project: mkrdip/alcf
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = cifar_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    # To enable manual device placement (e.g. GPU 1, 2), comment out the line below and uncomment the next one.
    mesh_devices = [""] * mesh_size
    # mesh_devices = ['GPU:' + str(i) for i in range(mesh_size)]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    labels = features['label']

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        # Variables that affect learning rate.
        num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)
        tf.summary.scalar('learning_rate', lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=mtf_lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
Example #4
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    if FLAGS.use_tpu:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info('device_list = %s', device_list)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Example #5
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    if FLAGS.use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      tf.logging.info("device_list = %s" % device_list,)
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      physical_shape = list(ctx.device_assignment.topology.mesh_shape)
      logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
          mesh_shape.to_integer_list, physical_shape)
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape,
          layout_rules,
          mesh_devices,
          ctx.device_assignment,
          logical_to_physical=logical_to_physical)
    else:
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, [""] * mesh_shape.size)
      var_placer = None

    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)

    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)
    max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
    max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                max_predictions_per_seq)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])
    mtf_masked_lm_positions = mtf.import_tf_tensor(
        mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
    mtf_masked_lm_ids = mtf.import_tf_tensor(
        mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

    mtf_masked_lm_weights = mtf.import_tf_tensor(
        mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
    mtf_next_sentence_labels = mtf.import_tf_tensor(
        mesh, next_sentence_labels, [batch_dim])

    is_training = (mode == tf_estimator.ModeKeys.TRAIN)

    model = bert_lib.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        token_type_ids=mtf_segment_ids,
        layout=layout_rules,
        mesh_shape=mesh_shape)

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_logits) = model.get_masked_lm_output(
         mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_logits) = model.get_next_sentence_output(
         mtf_next_sentence_labels)

    extra_loss = model.get_extra_loss()

    total_loss = masked_lm_loss + next_sentence_loss
    total_loss = mtf.anonymize(total_loss)
    masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
    masked_lm_logits = mtf.anonymize(masked_lm_logits)
    next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
    next_sentence_logits = mtf.anonymize(next_sentence_logits)

    # TRAIN mode
    if mode == tf_estimator.ModeKeys.TRAIN:
      _, update_ops = optimization_lib.create_optimizer(
          total_loss + extra_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

    if mode == tf_estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)
    elif mode == tf_estimator.ModeKeys.EVAL:

      def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                    masked_lm_weights, next_sentence_example_loss,
                    next_sentence_logits, next_sentence_labels):
        """Computes the loss and accuracy of the model."""
        masked_lm_logits = tf.reshape(masked_lm_logits,
                                      [-1, masked_lm_logits.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_logits, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        next_sentence_logits = tf.reshape(
            next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
        next_sentence_predictions = tf.argmax(
            next_sentence_logits, axis=-1, output_type=tf.int32)
        next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
        next_sentence_accuracy = tf.metrics.accuracy(
            labels=next_sentence_labels, predictions=next_sentence_predictions)
        next_sentence_mean_loss = tf.metrics.mean(
            values=next_sentence_example_loss)

        return {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
            "next_sentence_accuracy": next_sentence_accuracy,
            "next_sentence_loss": next_sentence_mean_loss,
        }

      eval_metrics = (metric_fn, [
          lowering.export_to_tf_tensor(masked_lm_example_loss),
          lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
          masked_lm_weights,
          lowering.export_to_tf_tensor(next_sentence_example_loss),
          lowering.export_to_tf_tensor(next_sentence_logits),
          next_sentence_labels
      ])

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf_estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        return tf_estimator.tpu.TPUEstimatorSpec(
            tf_estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      elif mode == tf_estimator.ModeKeys.EVAL:
        return tf_estimator.tpu.TPUEstimatorSpec(
            tf_estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)
Example #6
File: attention.py  Project: yanxg/mesh
def maybe_reshape_attention_input_for_2d_sharding(context, q, k, v, bias,
                                                  unsplittable_dims):
    """Reshape the inputs to attention to split over an unused mesh dimension.

  In the case where the attention computation is unnecessarily replicated,
  this function reshapes the attention inputs to remove the unnecessary
  replication.

  This becomes relevant when doing 2-dimensional model parallelism.
  d_model is sharded over one mesh dimension and [vocab, num_heads, d_ff] are
  sharded over the other mesh dimension.  This fully distributes all of the
  einsum operations, except for the internals of the attention computation.

  To distribute that computation, this function creates a new tensor-dimension
  from the low bits of either the batch dimension or the num_heads dimension,
  and then splits that dimension over the unused mesh dimension.

  Args:
    context: a transformer.Context
    q: a Tensor
    k: a Tensor
    v: a Tensor
    bias: a Tensor
    unsplittable_dims: a list of tensor-dimensions not to split.  The key/value
      dimensions should be passed here.
  Returns:
    reshaped_q: a Tensor
    reshaped_k: a Tensor
    reshaped_v: a Tensor
    reshaped_bias: a Tensor
  """
    original_inputs = q, k, v, bias
    # we need to know the layout and mesh-shape to figure out what to do.
    if not context or not context.model.layout or not context.model.mesh_shape:
        return original_inputs
    mesh_shape = mtf.convert_to_shape(context.model.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(context.model.layout)
    # find a mesh dim that is unused (no tensor-dimension is split across it)
    mesh_axis_used = [False] * mesh_shape.ndims
    for x in original_inputs:
        for mesh_axis in layout_rules.tensor_layout(
                x.shape, mesh_shape).tensor_axis_to_mesh_axis:
            if mesh_axis is not None:
                mesh_axis_used[mesh_axis] = True
    if False not in mesh_axis_used:
        return original_inputs
    mesh_dim = mesh_shape.dims[mesh_axis_used.index(False)]
    # Choose an appropriate name for the new tensor-dimension so that the layout
    #   will know to split it across the unused mesh dimension.
    tensor_dim_name = layout_rules.mesh_dimension_name_to_tensor_dimension_names(
        mesh_dim.name)
    if tensor_dim_name:
        tensor_dim_name = tensor_dim_name[0]
    else:
        return original_inputs
    # Find a tensor-dimension that we can further split, by breaking off the
    # lower bits into our new tensor-dimension.
    # This resplittable tensor-dimension must be present in all of q, k, v
    #   and must be large enough to be further split.
    resplittable_dim = None
    for d in q.shape.dims:
        if d in k.shape.dims and d in v.shape.dims and d not in unsplittable_dims:
            num_splits = mtf.tensor_dim_to_mesh_dim_size(
                context.model.layout, context.model.mesh_shape, d)
            if d.size % (num_splits * mesh_dim.size) == 0:
                resplittable_dim = d
                break
    if not resplittable_dim:
        return original_inputs
    new_dim_high = mtf.Dimension(resplittable_dim.name, num_splits)
    new_dim_low = mtf.Dimension(tensor_dim_name,
                                resplittable_dim.size // num_splits)

    def _my_reshape(x):
        if x and resplittable_dim in x.shape.dims:
            return mtf.replace_dimensions(x, resplittable_dim,
                                          [new_dim_high, new_dim_low])
        else:
            return x

    return _my_reshape(q), _my_reshape(k), _my_reshape(v), _my_reshape(bias)
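
A minimal sketch (not from the original file) of the unused-mesh-axis detection used above; the mesh shape, layout, and tensor shape are made-up values for illustration:

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("model:2,data:4")
# Only the "data" mesh axis is referenced by the layout, so "model" is unused.
layout_rules = mtf.convert_to_layout_rules("batch:data")
tensor_shape = mtf.Shape(
    [mtf.Dimension("batch", 8), mtf.Dimension("d_model", 16)])

mesh_axis_used = [False] * mesh_shape.ndims
for mesh_axis in layout_rules.tensor_layout(
        tensor_shape, mesh_shape).tensor_axis_to_mesh_axis:
    if mesh_axis is not None:
        mesh_axis_used[mesh_axis] = True
print(mesh_axis_used)  # [False, True]: mesh axis 0 ("model") is free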
Example #7
def main(_):

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    print("mesh_shape : ", mesh_shape)
    print("layout_rules : ", layout_rules)
    print("FLAGS.gpus_per_node : ", FLAGS.gpus_per_node)
    print("FLAGS.gpus_per_task : ", FLAGS.gpus_per_task)
    print("FLAGS.tasks_per_node : ", FLAGS.tasks_per_node)

    # Resolve the cluster from SLURM environment
    cluster = tf.distribute.cluster_resolver.SlurmClusterResolver(
        {"mesh": mesh_shape.size // FLAGS.gpus_per_task},
        port_base=8822,
        gpus_per_node=FLAGS.gpus_per_node,
        gpus_per_task=FLAGS.gpus_per_task,
        tasks_per_node=FLAGS.tasks_per_node)

    cluster_spec = cluster.cluster_spec()
    # Create a server for all mesh members
    server = tf.distribute.Server(cluster_spec, "mesh", cluster.task_id)

    # Only the master job takes care of the graph building,
    # everyone else can just chill for now
    if cluster.task_id > 0:
        server.join()

    # Otherwise we are the main task, let's define the devices
    mesh_devices = [
        "/job:mesh/task:%d/device:GPU:%d" % (i, j)
        for i in range(cluster_spec.num_tasks("mesh"))
        for j in range(FLAGS.gpus_per_task)
    ]
    print("List of devices", mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    # Build the model
    fft_err = benchmark_model(mesh)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    # Retrieve output of computation
    result = lowering.export_to_tf_tensor(fft_err)
    print('Lowering done')

    with tf.Session(server.target) as sess:

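        # Warm-up run: build and run the graph once so the timing loop below
        # measures steady-state execution only.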
        start = time.time()
        err = sess.run(result)
        end = time.time()

        niter = int(100 // np.log2(
            FLAGS.cube_size))  # since large meshes might take a lot of time
        start = time.time()
        for i in range(niter):
            err = sess.run(result)
        end = time.time()
        ttime = (end - start) / niter
        print('Time for ', mesh_shape, ' is : ', ttime)

        ### Uncomment this to get the output of a profiler
        ##    profiler = tf.profiler.Profiler(sess.graph)
        ##
        ##    run_meta = tf.RunMetadata()
        ##    err = sess.run(result,
        ##                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        ##                   run_metadata=run_meta)
        ##
        ##    profiler.add_step(0, run_meta)
        ##
        ##    opts = (tf.profiler.ProfileOptionBuilder(
        ##        tf.profiler.ProfileOptionBuilder.time_and_memory())
        ##        .with_step(0)
        ##        .with_timeline_output(FLAGS.output_file).build())
        ##    profiler.profile_graph(options=opts)
        ##

        ### This is another way of profiling
        ##    profiler = tf.profiler.Profiler(sess.graph)
        ##
        ##    run_meta = tf.RunMetadata()
        ##    err = sess.run(result,
        ##                   options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        ##                   run_metadata=run_meta)
        ##
        ##   # Create the Timeline object, and write it to a json
        ##    tl = timeline.Timeline(run_meta.step_stats)
        ##    ctf = tl.generate_chrome_trace_format()
        ##    with open('timelinev2-%d-%d.json'%(FLAGS.cube_size, FLAGS.max_depth), 'w') as f:
        ##        f.write(ctf)
        ##
        ##    profiler.add_step(0, run_meta)
        ##
        ##    logfile = str(FLAGS.output_file) + '-%d'%FLAGS.max_depth
        ##    opts = (tf.profiler.ProfileOptionBuilder(tf.profiler.ProfileOptionBuilder.time_and_memory())
        ##            .with_step(0)
        ##            .with_timeline_output(logfile)
        ##            .with_stdout_output()
        ##            .with_max_depth(FLAGS.max_depth).build()
        ##    )
        ##
        ##    profiler.profile_graph(options=opts)
        ##

        print("Max absolute FFT error %f, with wall time %f" % (err, ttime))
        exit(-1)
Example #8
def get_layout():
    return mtf.convert_to_layout_rules(FLAGS.layout)
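
For context (illustrative, not from the original source): the layout flag is a comma-separated list of tensor-dimension:mesh-dimension pairs, e.g.:

# Illustrative value; FLAGS.layout is normally supplied on the command line.
layout_rules = mtf.convert_to_layout_rules("batch:mesh_x,hidden:mesh_y")
# Tensor dimensions named "batch" are split across the "mesh_x" mesh axis,
# and those named "hidden" across "mesh_y".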
Example #9
def dalle_model_fn(features, labels, mode, params):
    # No separate targets are needed: `features` is the text input and
    # `labels` is the image input, so targets can be inferred from the inputs.
    global_step = tf.train.get_global_step()  # Get global step

    mode_str = mode_to_str(mode)

    # load vae in tensorflow graph before mtf
    vae, vae_checkpoint_path = load_vae_model(params, mode_str)

    initialize_vae_weights(vae_checkpoint_path)

    H = W = params["dataset"]["image_size"]
    image_seq_len = (vae.H // (2**len(vae.convblocks)))**2 // (
        vae.stack_factor**2)  # TODO: check this is correct
    batch_size = params[f"{mode_str}_batch_size"]
    n_channels = params.get("input_channels", 3)

    with tf.variable_scope("vae"):
        vae_logits = vae.forward(features, return_logits=True)

    # TODO: using argmax sampling for now, but is that optimal?
    tokens = tf.math.argmax(vae_logits, -1)
    img_tokens_reshaped = tf.cast(
        tf.reshape(tokens, (batch_size, image_seq_len)), tf.int32)

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    model = DALLE(
        n_embd=params["n_embd"],
        text_vocab_size=params["text_vocab_size"],
        image_vocab_size=params["image_vocab_size"],
        text_seq_len=params["text_seq_len"],
        image_seq_len=image_seq_len,
        n_layers=params["n_layers"],
        n_heads=params["n_heads"],
        batch_size=batch_size,
        bf_16=params["bf_16"],
        mode=mode_str,
        params=params,
    )

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"image_inputs": features, "text_inputs": labels}
    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            if key == "text_inputs":
                text_tokens = tf.reshape(x,
                                         [batch_size, params["text_seq_len"]])
                x = tf.concat(
                    (text_tokens, img_tokens_reshaped + model.text_vocab_size),
                    axis=1)
                mtf_shape = mtf.Shape([
                    model.dimensions["batch_dim"],
                    model.dimensions["total_seq_dim"]
                ])

                mtf_features["tokens"] = mtf.import_fully_replicated(mesh,
                                                                     x,
                                                                     mtf_shape,
                                                                     name=key)

            if key == "image_inputs":
                mtf_shape = mtf.Shape([
                    model.dimensions["batch_dim"],
                    mtf.Dimension("img_height_dim", vae.H),
                    mtf.Dimension("img_width_dim", vae.W),
                    mtf.Dimension("img_channel_dim", vae.num_ch),
                ])
                x = tf.reshape(x, [batch_size, H, W, n_channels])  # NHWC
                mtf_features["image_inputs"] = mtf.import_fully_replicated(
                    mesh, x, mtf_shape, name=key)

    scalar_summary("input_image", mtf_features["image_inputs"])
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=model.dimensions["batch_dim"],
                sequence_length=model.total_seq_dim,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:
        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            loss, loss_batch = model.forward(mtf_features, return_loss=True)
            return {"loss": loss, "loss_batch": loss_batch}

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, model.dimensions["batch_dim"],
            num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
    else:
        loss, loss_batch = model.forward(mtf_features, return_loss=True)

    del loss_batch  # TODO: may need this for some metrics - otherwise, remove from output

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if num_microbatches > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=model.variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=model.variable_dtype)
        # Log summaries to tensorboard
        scalar_summary("loss", loss)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=False)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=params.get(
                                       "max_checkpoints", 5),
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=None)
Example #10
File: utils.py  Project: gavinljj/mesh
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        eval_steps=10,
        train_steps=1000000,
        batch_size=auto_batch_size,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        get_components_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "lm" or "aligned"
    vocabulary: a vocabulary.Vocabulary
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=train
    eval_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=eval
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_path: a string - which checkpoint to load for inference
    mode: string, train/evaluate/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    eval_steps: integer, number of evaluation steps
    train_steps: Total number of training steps.
    batch_size: An integer or a function with the same signature as
      auto_batch_size().  Mini-batch size for the training. Note that this is
      the global batch size and not the per-shard batch size.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    get_components_fn: an optional function that gets a list of tuples of
      (metric_names, component) for each component. Required if mode is
      "continuous_eval"
  """
  if not isinstance(batch_size, int):
    batch_size = batch_size(sequence_length, mesh_shape, layout_rules)

  tf.logging.info("mode=%s" % mode,)
  tf.logging.info("batch_size=%s" % batch_size,)
  tf.logging.info("sequence_length=%s" % sequence_length,)
  tf.logging.info("mesh_shape=%s" % mesh_shape,)
  tf.logging.info("layout_rules=%s" % layout_rules,)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name)
  )
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      tpu_config=my_tpu_config)

  transformer_model = build_model(
      model_type=model_type,
      vocab_size=vocabulary.vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      metric_names=None)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")
    def input_fn(params):
      del params
      dataset = train_dataset_fn(batch_size=batch_size,
                                 sequence_length=sequence_length,
                                 vocabulary=vocabulary,
                                 dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "continuous_eval":
    if get_components_fn is None:
      raise ValueError("Must provide get_components_fn through gin for eval.")
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")
    metrics_inputs = get_components_fn()
    for _ in tf.contrib.training.checkpoints_iterator(estimator.model_dir):
      for metric_names, component in metrics_inputs:
        tf.logging.info("Evaluating {}".format(component.__dict__))
        tf.logging.info("on split {}".format(dataset_split))
        # Prepend eval tag and split name to metric names
        metric_names = [
            "eval/{}/{}".format(dataset_split, n) for n in metric_names
        ]
        # Regenerate the estimator
        model_fn = tpu_estimator_model_fn(
            model_type=model_type,
            transformer_model=transformer_model,
            model_dir=model_dir,
            use_tpu=tpu,
            mesh_shape=mesh_shape,
            layout_rules=layout_rules,
            batch_size=batch_size,
            sequence_length=sequence_length,
            autostack=autostack,
            metric_names=metric_names)
        estimator = tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            config=run_config,
            train_batch_size=batch_size,
            eval_batch_size=batch_size,
            predict_batch_size=batch_size,
            use_tpu=tpu,
            export_to_tpu=False,
            params={})
        def input_fn(params):
          del params
          dataset = eval_dataset_fn(component,  # pylint: disable=cell-var-from-loop
                                    batch_size=batch_size,
                                    sequence_length=sequence_length,
                                    vocabulary=vocabulary,
                                    dataset_split=dataset_split,
                                    pack=False)
          return dataset

        eval_args = {"eval": (input_fn, eval_steps)}
        _ = evaluate(estimator, eval_args)

  elif mode == "infer":
    decode_from_file(
        estimator,
        vocabulary=vocabulary,
        model_type=model_type,
        batch_size=batch_size,
        sequence_length=sequence_length,
        checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/evaluate/continuous_eval/infer" % mode)
Example #11
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        num_eval_examples=None,
        get_components_fn=None,
        compute_metrics_from_file_fn=None,
        learning_rate_schedule=None,
        optimizer=None):
    """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or
      (inputs_vocabulary, targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=train
    eval_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=eval
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_path: a string - which checkpoint to load for inference
    mode: string, train/evaluate/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is
      the global batch size and not the per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    num_eval_examples: maximum number of examples per task to use for continuous
      eval.
    get_components_fn: an optional function that returns a list of tuples of
      (metric_names, component) for each component.
      Required if mode is "continuous_eval."
    compute_metrics_from_file_fn: an optional function that takes in: component,
      metric names (list of strs), targets (list of strs), predictions (list of
      strs), dataset_split (str), and tb_summary_dir (str), runs metrics on
      targets and predictions, and returns a dictionary of metrics and their
      computed values. Required if mode is "continuous_eval."
    learning_rate_schedule: an optional function taking the scalar name
      argument `step` and the numeric argument `total_train_steps` and return
      the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
  """
    if not isinstance(batch_size, int):
        batch_size = compute_batch_size(sequence_length, mesh_shape,
                                        layout_rules, batch_size)

    if not isinstance(train_steps, int):
        train_steps = train_steps(batch_size, sequence_length)

    if callable(learning_rate_schedule):
        learning_rate_schedule = functools.partial(
            learning_rate_schedule, total_train_steps=train_steps)

    tf.logging.info("model_type=%s" % model_type, )
    tf.logging.info("mode=%s" % mode, )
    tf.logging.info("sequence_length=%s" % sequence_length, )
    tf.logging.info("batch_size=%s" % batch_size, )
    tf.logging.info("train_steps=%s" % train_steps, )
    tf.logging.info("mesh_shape=%s" % mesh_shape, )
    tf.logging.info("layout_rules=%s" % layout_rules, )

    if mode == "train" and dataset_split != "train":
        raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

    mesh_shape = mtf.convert_to_shape(mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(layout_rules)

    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

    tf.logging.info(
        "Building TPUConfig with tpu_job_name={}".format(tpu_job_name))
    my_tpu_config = tpu_config.TPUConfig(
        tpu_job_name=tpu_job_name,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
    )

    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        tpu_config=my_tpu_config,
        # We use a saver hook, so disable checkpoints here to prevent double
        # saving.
        save_checkpoints_steps=None,
        save_checkpoints_secs=None)

    transformer_model = build_model(
        model_type=model_type,
        input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
        output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
        layout_rules=layout_rules,
        mesh_shape=mesh_shape)

    model_fn = tpu_estimator_model_fn(
        model_type=model_type,
        transformer_model=transformer_model,
        model_dir=model_dir,
        use_tpu=tpu,
        mesh_shape=mesh_shape,
        layout_rules=layout_rules,
        batch_size=batch_size,
        sequence_length=sequence_length,
        autostack=autostack,
        learning_rate_schedule=learning_rate_schedule,
        keep_checkpoint_max=keep_checkpoint_max,
        save_checkpoints_steps=save_checkpoints_steps,
        optimizer=optimizer)

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           config=run_config,
                                           train_batch_size=batch_size,
                                           eval_batch_size=batch_size,
                                           predict_batch_size=batch_size,
                                           use_tpu=tpu,
                                           export_to_tpu=False,
                                           params={})

    if mode == "train":
        if train_dataset_fn is None:
            raise ValueError(
                "Must provide train_dataset_fn through gin for train.")

        def input_fn(params):
            del params
            dataset = train_dataset_fn(batch_size=batch_size,
                                       sequence_length=sequence_length,
                                       vocabulary=vocabulary,
                                       dataset_split=dataset_split)
            return dataset

        estimator.train(input_fn=input_fn, max_steps=train_steps)
    elif mode == "continuous_eval":
        if eval_dataset_fn is None:
            raise ValueError(
                "Must provide eval_dataset_fn through gin for eval.")
        if get_components_fn is None:
            raise ValueError(
                "Must provide get_components_fn through gin for eval.")
        if compute_metrics_from_file_fn is None:
            raise ValueError(
                "Must provide compute_metrics_from_file_fn through gin for eval."
            )

        metrics_inputs = get_components_fn()
        for ckpt in tf.contrib.training.checkpoints_iterator(
                estimator.model_dir):
            for metric_names, component in metrics_inputs:
                if not metric_names:
                    tf.logging.info("Skipping %s", component.__dict__)
                    continue
                tf.logging.info("Evaluating %s on metrics %s",
                                component.tfds_name, component.metric_names)
                tf.logging.info("on split %s", dataset_split)

                # Regenerate the estimator
                model_fn = tpu_estimator_model_fn(
                    model_type=model_type,
                    transformer_model=transformer_model,
                    model_dir=model_dir,
                    use_tpu=tpu,
                    mesh_shape=mesh_shape,
                    layout_rules=layout_rules,
                    batch_size=batch_size,
                    sequence_length=sequence_length,
                    autostack=autostack,
                    keep_checkpoint_max=keep_checkpoint_max,
                    save_checkpoints_steps=save_checkpoints_steps)
                estimator = tpu_estimator.TPUEstimator(
                    model_fn=model_fn,
                    config=run_config,
                    train_batch_size=batch_size,
                    eval_batch_size=batch_size,
                    predict_batch_size=batch_size,
                    use_tpu=tpu,
                    export_to_tpu=False,
                    params={})

                # Extra eval_dataset_fn call to get the dataset_size and an extra
                # dataset object to write out targets. We need to use a separate graph
                # because estimator finalizes the default graph after iterating over the
                # dataset.
                dataset_graph = tf.Graph()
                with dataset_graph.as_default():
                    dataset, dataset_size, padded_dataset_size = eval_dataset_fn(
                        component,  # pylint: disable=cell-var-from-loop
                        batch_size=batch_size,
                        sequence_length=sequence_length,
                        vocabulary=vocabulary,
                        dataset_split=dataset_split,
                        pack=False,
                        max_dataset_size=num_eval_examples)

                def input_fn(params):
                    del params
                    dataset, _, _ = eval_dataset_fn(
                        component,  # pylint: disable=cell-var-from-loop
                        batch_size=batch_size,
                        sequence_length=sequence_length,
                        vocabulary=vocabulary,
                        dataset_split=dataset_split,
                        pack=False,
                        max_dataset_size=num_eval_examples)
                    return dataset

                dataset_name = component.tfds_name.replace("/", "-").replace(
                    ":", "-")
                output_filename = os.path.join(
                    model_dir, "{}-{}-decoded".format(dataset_name,
                                                      dataset_split))
                pred_output_filename = output_filename + "-preds-test"
                target_output_filename = output_filename + "-targets-test"
                decodes = decode(estimator,
                                 input_fn,
                                 dataset_size,
                                 padded_dataset_size,
                                 batch_size,
                                 vocabulary,
                                 checkpoint_path=checkpoint_path)
                with dataset_graph.as_default():
                    log_pred_target(
                        decodes,
                        dataset,
                        dataset_size,
                        vocabulary,
                        pred_output_filename=pred_output_filename,
                        target_output_filename=target_output_filename)
                tf.logging.info("Evaluating metrics: {}".format(metric_names))
                tb_summary_dir = os.path.join(
                    model_dir,
                    "{}_eval".format("eval" if dataset_split ==
                                     "validation" else dataset_split))
                summary_writer = tf.summary.FileWriter(tb_summary_dir)
                _ = compute_metrics_from_file_fn(component,
                                                 pred_output_filename,
                                                 target_output_filename,
                                                 dataset_split,
                                                 tb_summary_dir,
                                                 ckpt,
                                                 summary_writer=summary_writer)

    elif mode == "infer":
        decode_from_file(estimator,
                         vocabulary=vocabulary,
                         model_type=model_type,
                         batch_size=batch_size,
                         sequence_length=sequence_length,
                         checkpoint_path=checkpoint_path)
    else:
        raise ValueError(
            "unknown mode %s - must be train/continuous_eval/infer" % mode)
Example #12
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
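The non-TPU branch above reduces to a small reusable pattern: build an mtf graph, pick a PlacementMeshImpl, lower, and export back to TensorFlow. A minimal self-contained sketch (mesh and dimension names are made up for illustration):

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "toy_mesh")
x_dim = mtf.Dimension("x", 4)
x = mtf.import_tf_tensor(mesh, tf.ones([4]), mtf.Shape([x_dim]))
y = mtf.reduce_sum(x)  # reduce over all dimensions -> scalar mtf.Tensor

# One mesh dimension of size 1 on a single default device; "x" maps onto it.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"),
    mtf.convert_to_layout_rules("x:all"),
    [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_y = lowering.export_to_tf_tensor(y)

with tf.Session() as sess:
    print(sess.run(tf_y))  # 4.0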
Example #13
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = model_backbone(features, labels, mesh)

    variables = graph._all_variables
    for v in variables:
        logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(
            v.name, v.shape, v.dtype.master_dtype))
    mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
    # layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    estimator = memory_estimator.MemoryEstimator(graph, mesh_shape,
                                                 [logits, loss])
    optimizer = layout_optimizer.LayoutOptimizer(estimator,
                                                 scheduler_alg="NAIVE")
    layout_rules = mtf.convert_to_layout_rules(optimizer.solve())

    logger.info("[auto mtf search] strategy: {}".format(layout_rules))
    mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.SgdOptimizer(0.01)
        # optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        predicts = tf.sigmoid(tf_logits)
        # Threshold the sigmoid outputs at 0.5 to get hard 0/1 predictions.
        predicts = tf.where(predicts > 0.5, tf.ones_like(predicts),
                            tf.zeros_like(predicts))
        accuracy = tf.metrics.accuracy(labels=labels, predictions=predicts)

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        logging_hook = tf.train.LoggingTensorHook(every_n_iter=100,
                                                  tensors={
                                                      'loss': 'cross_entropy',
                                                      'acc': 'train_accuracy'
                                                  })

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, logging_hook])
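The commented-out line near the top of this example points at the shorter path: mesh_tensorflow ships an auto_mtf module wrapping exactly this MemoryEstimator + LayoutOptimizer pair. The one-call equivalent, given the same graph, mesh_shape, and outputs:

import mesh_tensorflow.auto_mtf  # registers mtf.auto_mtf

layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])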
Example #14
File: utils.py Project: masak1112/mesh
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_step=None,
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        eval_summary_dir=None,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        learning_rate_schedule=None,
        optimizer=None,
        predict_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode="train". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
      Must be provided for mode="eval". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples are namedtuples with the following fields:
        - name: string, the task name
        - dataset_fn: function which returns a tf.data.Dataset of tokenized and
          padded examples. Must not require any arguments and must include the
          feature keys 'inputs' and 'targets_plaintext'.
        - postprocess_fn: function which converts model outputs to evalable str
        - list_of_metric_fns: list of metric functions with the call signature
          `metric_fn(targets, predictions)` which return either a scalar value
          or a dict mapping submetric names to scalar values. TensorBoard
          summaries and other tags will be written out using
          `metric_fn.__name__`.
        - dataset_size: number of entries in the dataset.
        - padded_dataset_size: number of entries in the dataset after padding.
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_step: int, list of ints, or None. Only used when mode="eval" or
      mode="infer". If an int or list of ints, evaluation or inference will be
      run on the checkpoint files in `model_dir` whose global steps are closest
      to the global steps provided. If None and mode="eval", run eval
      continuously waiting for new checkpoints via
      `tf.contrib.training.checkpoints_iterator`.
    mode: string, train/eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not the
      per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: an optional function taking the scalar named
      argument `step` and the numeric argument `total_train_steps` and
      returning the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape [batch_dim,
      length_dim] that will be the prediction for each example. Must accept the
      following arguments:
        - model: a Unitransformer or Bitransformer
        - features: a dict representing an example. Every value will be an
          mtf.Tensor with shape [batch_dim, length_dim].
        - variable_dtype: an mtf.VariableDType
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s" % model_type,)
  tf.logging.info("mode=%s" % mode,)
  tf.logging.info("sequence_length=%s" % sequence_length,)
  tf.logging.info("batch_size=%s" % batch_size,)
  tf.logging.info("train_steps=%s" % train_steps,)
  tf.logging.info("mesh_shape=%s" % mesh_shape,)
  tf.logging.info("layout_rules=%s" % layout_rules,)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name)
  )
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")
    def input_fn(params):
      del params
      dataset = train_dataset_fn(batch_size=batch_size,
                                 sequence_length=sequence_length,
                                 vocabulary=vocabulary,
                                 dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)

  elif mode == "eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        batch_size=batch_size,
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    # Pre-load in all of the targets once before entering continuous eval loop
    cached_targets = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        # Only cache targets for tasks that provide eval metric functions
        if eval_dataset.metric_fns:
          ds = eval_dataset.dataset_fn()
          # De-batch the dataset
          ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
          ds = tfds.as_numpy(ds)
          targets = [
              eval_dataset.postprocess_fn(d["targets_plaintext"]) for d in ds
          ]
          targets = targets[:eval_dataset.dataset_size]
          cached_targets[eval_dataset.name] = targets

    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        if not eval_dataset.metric_fns:
          tf.logging.info(
              "Skipping %s because metric_fns is empty", eval_dataset.name
          )
          continue
        metric_names = [metric.__name__ for metric in eval_dataset.metric_fns]
        tf.logging.info(
            "Evaluating %s on metrics %s", eval_dataset.name, metric_names
        )
        tf.logging.info("on split %s", dataset_split)

        def input_fn(params):
          del params
          ds = eval_dataset.dataset_fn()
          # Only pass those variables which will be used for decoding
          ds = ds.map(
              lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES}
          )
          return ds

        decodes = decode(
            estimator,
            input_fn,
            eval_dataset.dataset_size,
            eval_dataset.padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path,
        )
        predictions = [eval_dataset.postprocess_fn(d) for d in decodes]
        # TODO(craffel): Log predictions and targets.

        eval_summary_dir = eval_summary_dir or os.path.join(
            model_dir, "{}_eval".format(dataset_split)
        )
        summary_writer = tf.summary.FileWriter(eval_summary_dir)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        for metric_fn in eval_dataset.metric_fns:
          summary = tf.Summary()
          tag = "eval/{}/{}/{}".format(
              eval_dataset.name, dataset_split, metric_fn.__name__
          )
          targets = cached_targets[eval_dataset.name]
          metric_result = metric_fn(targets, predictions)
          if isinstance(metric_result, dict):
            tags = ["{}.{}".format(tag, key) for key in metric_result]
            metric_values = metric_result.values()
          else:
            tags, metric_values = [tag], [metric_result]
          for tag, metric_value in zip(tags, metric_values):
            tf.logging.info(
                "%s at step %d: %.3f", tag, global_step, metric_value
            )
            summary.value.add(tag=tag, simple_value=metric_value)
            summary_writer.add_summary(summary, global_step)
        summary_writer.flush()

  elif mode == "infer":
    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      decode_from_file(
          estimator,
          vocabulary=vocabulary,
          model_type=model_type,
          batch_size=batch_size,
          sequence_length=sequence_length,
          checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/eval/infer" % mode)
Example #15
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if isinstance(features_dict[key], dict):
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need a separate dimension here because gathering breaks when both
    # arguments share a dimension. This dim is used specifically for the
    # weights, and prevents the "Einsum has lhs dimension without
    # corresponding rhs or output dimension." error.
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'],
                max_steps=params["predict_max_steps"])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    # Add num_microbatches to params.
    params["num_microbatches"] = num_microbatches

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
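                # loss is in nats per token; multiplying by a tokens-per-byte
                # ratio (0.29335, presumably specific to this tokenizer and
                # corpus) gives nats per byte, and dividing by ln(2) converts
                # nats to bits.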
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # tf_loss_batch may include z_loss and other auxiliary terms,
                # so this should perhaps be computed separately in the future.
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
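The microbatch count above comes from mtf_transformer.utils.serialize_num_microbatches. A rough sketch of the quantity it balances (illustrative arithmetic with made-up numbers, not the library's exact formula):

# Hypothetical settings, not from the source.
batch_size = 512                     # global batch, in sequences
n_ctx = 1024                         # tokens per sequence
num_replicas = 8                     # data-parallel copies implied by the mesh
tokens_per_mb_per_replica = 16384

tokens_per_replica = batch_size * n_ctx // num_replicas             # 65536
num_microbatches = tokens_per_replica // tokens_per_mb_per_replica  # 4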
Example #16
File: utils.py Project: appcoreopc/mesh
def run(tpu_job_name,
        data_dir,
        master_dtype,
        slice_dtype,
        activation_dtype,
        tpu,
        gcp_project,
        tpu_zone,
        autostack,
        model_dir,
        mode=gin.REQUIRED,
        iterations_per_loop=gin.REQUIRED,
        save_checkpoints_steps=gin.REQUIRED,
        eval_steps=gin.REQUIRED,
        train_steps=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        text2self=gin.REQUIRED,
        dataset=gin.REQUIRED):
    """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    data_dir: string, data_dir for TensorFlow Datasets
    master_dtype: string, datatype for checkpoints
    slice_dtype: string, datatype for variables in memory
    activation_dtype: string, datatype for activations
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    autostack: boolean, internally combine variables
    model_dir: string, estimator model_dir
    mode: string, train/evaluate/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    eval_steps: integer, number of evaluation steps
    train_steps: Total number of training steps.
    batch_size: Mini-batch size for the training. Note that this is the global
      batch size and not the per-shard batch.
    text2self: Whether to train a language model (True) or encoder-decoder
      text-to-text model (False).
    dataset: TensorFlow Datasets dataset name.
  """
    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu if (tpu) else "", zone=tpu_zone, project=gcp_project)

    my_tpu_config = tpu_config.TPUConfig(
        tpu_job_name=tpu_job_name,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
    )

    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        tpu_config=my_tpu_config)

    dataset = transformer_dataset.TokenizedTFDSDataset(
        dataset, text2self=text2self, data_dir=data_dir or None)

    output_encoder = dataset.encoders["targets"]
    if text2self:
        input_encoder = output_encoder
    else:
        input_encoder = dataset.encoders["inputs"]

    transformer_model = model(
        input_vocab_size=transformer_dataset.padded_vocab_size(
            input_encoder.vocab_size, 128),
        output_vocab_size=transformer_dataset.padded_vocab_size(
            output_encoder.vocab_size, 128),
        text2self=text2self)
    mesh_shape = mtf.convert_to_shape(gin.query_parameter("model.mesh_shape"))
    layout_rules = mtf.convert_to_layout_rules(
        gin.query_parameter("model.layout"))
    # Data-types used for variables and activations
    # See comments in the FLAGS
    master_dtype = tf.as_dtype(master_dtype)
    if slice_dtype:
        slice_dtype = tf.as_dtype(slice_dtype)
    elif not tpu or mode == "train":
        slice_dtype = tf.float32
    else:
        slice_dtype = tf.bfloat16
    if activation_dtype:
        activation_dtype = tf.as_dtype(activation_dtype)
    else:
        activation_dtype = tf.bfloat16 if tpu else tf.float32
    variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                       slice_dtype=slice_dtype,
                                       activation_dtype=activation_dtype)

    length_from_config = gin.query_parameter(
        "model.length") or gin.query_parameter("model.max_length")

    model_fn = tpu_estimator_model_fn(transformer_model=transformer_model,
                                      model_dir=model_dir,
                                      use_tpu=tpu,
                                      mesh_shape=mesh_shape,
                                      layout_rules=layout_rules,
                                      text2self=text2self,
                                      variable_dtype=variable_dtype,
                                      batch_size=batch_size,
                                      length=length_from_config,
                                      autostack=autostack)

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           config=run_config,
                                           train_batch_size=batch_size,
                                           eval_batch_size=batch_size,
                                           predict_batch_size=batch_size,
                                           use_tpu=tpu,
                                           export_to_tpu=False,
                                           params={})

    def input_fn(params):
        del params
        return dataset.load(batch_size=batch_size,
                            length=length_from_config,
                            train=(mode == "train"),
                            pack=True)

    if mode == "train":
        estimator.train(input_fn=input_fn, max_steps=train_steps)
    elif mode == "evaluate":
        estimator.evaluate(
            input_fn=input_fn,
            steps=eval_steps,
        )
    elif mode == "infer":
        decode_from_file(estimator,
                         batch_size=batch_size,
                         length=length_from_config,
                         inputs_encoder=dataset.encoders[
                             "targets" if text2self else "inputs"],
                         targets_encoder=dataset.encoders["targets"],
                         text2self=text2self)
    else:
        raise ValueError("unknown mode %s - must be train/evaluate/infer" %
                         mode)
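transformer_dataset.padded_vocab_size(vocab_size, 128) above rounds the vocabulary up so the embedding and softmax dimensions are TPU-friendly multiples of 128. A plausible re-implementation, for illustration only (the real helper lives in transformer_dataset):

def padded_vocab_size(vocab_size, divisor=128):
    # Round up to the next multiple of divisor, e.g. 32100 -> 32128.
    return vocab_size + (-vocab_size % divisor)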
Example #17
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if data_parallelism is None or len(
                    data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
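
Distilled from the example above: every TRAIN-mode model_fn in this collection follows the same skeleton — compute mtf gradients, apply them with an mtf optimizer, lower the mtf graph to TensorFlow, then group the lowered update ops into a train_op. A minimal sketch of just that skeleton, assuming `graph`, `mesh`, `mesh_impl`, and a scalar mtf `loss` already exist as above:

# Sketch of the recurring mtf train-op pattern (names assumed from above).
var_grads = mtf.gradients(
    [loss], [v.outputs[0] for v in graph.trainable_variables])
optimizer = mtf.optimize.AdafactorOptimizer()
update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

lowering = mtf.Lowering(graph, {mesh: mesh_impl})   # mtf graph -> tf graph
tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
tf_update_ops.append(tf.assign_add(tf.train.get_global_step(), 1))
train_op = tf.group(tf_update_ops)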
Example #18
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list,)
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        segment_ids=mtf_segment_ids)

    if mode == tf.estimator.ModeKeys.TRAIN:

      def compute_loss(logits, positions):
        one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
        log_probs = mtf.log_softmax(logits, seq_dim)
        loss = -mtf.reduce_mean(
            mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
        return loss

      start_positions = features["start_positions"]
      mtf_start_positions = mtf.import_tf_tensor(mesh, start_positions,
                                                 [batch_dim])
      end_positions = features["end_positions"]
      mtf_end_positions = mtf.import_tf_tensor(mesh, end_positions, [batch_dim])

      start_loss = compute_loss(start_logits, mtf_start_positions)
      end_loss = compute_loss(end_logits, mtf_end_positions)

      total_loss = (start_loss + end_loss) / 2.0
      _, update_ops = optimization_lib.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          max_optimized_variable_size=FLAGS.max_optimized_variable_size,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      start_logits = mtf.anonymize(start_logits)
      end_logits = mtf.anonymize(end_logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf.estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook],
            scaffold_fn=scaffold_fn)
      elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "unique_ids": unique_ids,
            "start_logits": lowering.export_to_tf_tensor(start_logits),
            "end_logits": lowering.export_to_tf_tensor(end_logits),
        }

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions=predictions,
            scaffold_fn=scaffold_fn)
      else:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))
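
The boundary crossings in the example above are worth isolating: TF feature tensors enter mesh TensorFlow via mtf.import_tf_tensor with explicitly named Dimensions, and results come back via lowering.export_to_tf_tensor. A self-contained sketch with illustrative shapes (the "all:1" mesh and the tensor sizes are assumptions, not taken from the example):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf  # these examples use the TF1 API

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")
tf_x = tf.zeros([8, 128], dtype=tf.int32)   # stand-in for input_ids
batch_dim = mtf.Dimension("batch", 8)
seq_dim = mtf.Dimension("seq", 128)
mtf_x = mtf.import_tf_tensor(mesh, tf_x, mtf.Shape([batch_dim, seq_dim]))
# ... build mtf ops on mtf_x here ...
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"), mtf.convert_to_layout_rules("batch:all"),
    [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_x_back = lowering.export_to_tf_tensor(mtf_x)  # plain tf.Tensor again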
Example #19
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout), mesh_devices,
        params['context'].device_assignment)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Example #20
 def testConvertToLayoutRules(self, inputs):
     layout_rules = mtf.convert_to_layout_rules(inputs)
     self.assertEqual(
         layout_rules._pairs,
         mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)
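
The parameterized test above checks string parsing; mtf.convert_to_layout_rules also accepts a list of (tensor_dim, mesh_dim) pairs or an existing LayoutRules object. A small sketch (the input string here is an assumption consistent with the pairs the test expects):

import mesh_tensorflow as mtf

rules_from_str = mtf.convert_to_layout_rules("d_ff:model,heads:model")
rules_from_list = mtf.convert_to_layout_rules([("d_ff", "model"),
                                               ("heads", "model")])
# Both forms should produce the same rule pairs.
assert rules_from_str._pairs == rules_from_list._pairs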
Example #21
File: mnist.py  Project: qixiuai/mesh
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
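
Example #21 passes empty device strings to PlacementMeshImpl, which places every slice on the default device. To actually spread slices across local accelerators, mesh_devices lists one device string per mesh slot; a hedged sketch (the device names and the 2-way mesh are illustrative):

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("all:2")        # illustrative 2-way mesh
layout_rules = mtf.convert_to_layout_rules("batch:all")
mesh_devices = ["gpu:0", "gpu:1"]                 # one device per mesh slot
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, mesh_devices)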
Example #22
 def testConvertToLayoutRulesGenericInputs(self):
     with self.assertRaises(ValueError):
         mtf.convert_to_layout_rules("d_ff;heads")
Example #23
    def __init__(
            self,
            model_dir,
            tpu,
            tpu_job_name=None,
            tpu_zone=None,
            gcp_project=None,
            tpu_topology="2x2",
            model_parallelism=8,
            batch_size=("tokens_per_batch", 1024),
            sequence_length=None,
            model_type="bitransformer",
            layout_rules="ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch",
            autostack=True,
            learning_rate_schedule=None,
            keep_checkpoint_max=None,
            save_checkpoints_steps=5000,
            optimizer=None,
            predict_fn=None,
            variable_filter=None,
            ensemble_inputs=None,
            iterations_per_loop=100):
        """Constructor for MtfModel class.

    Args:
      model_dir: str, directory to save the model.
      tpu: str, the TPU address to use.
      tpu_job_name: str, name of the TPU worker binary.
      tpu_zone: str, GCE zone where the Cloud TPU is located.
      gcp_project: str, project name for the Cloud TPU-enabled project.
      tpu_topology: str, e.g. "2x2".
      model_parallelism: integer, the number of cores per model replica.
      batch_size: An integer or a (method, value) pair to pass to
        compute_batch_size(). Note that this is the global batch size and not
        the per-shard batch size.
      sequence_length: an integer or a dict from feature-key to integer giving
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
      model_type: str, a model type from mesh tf models.
      layout_rules: an input to mtf.convert_to_layout_rules()
      autostack: boolean, internally combine variables.
      learning_rate_schedule: an optional function taking the scalar argument
        `step` and the numeric argument `total_train_steps` and returning the
        scalar learning rate.
      keep_checkpoint_max: an integer, maximum number of checkpoints to keep.
      save_checkpoints_steps: an integer, steps per checkpoint.
      optimizer: a class extending optimize.Optimizer, required for training.
      predict_fn: an optional function that can be used to override the default
        transformer prediction behavior. Must return a tensor of shape
        [batch_dim, length_dim] that will be the prediction for each example.
        Must accept the following arguments:
          - model: a Unitransformer or Bitransformer
          - features: a dict representing an example. Every value will be an
            mtf.Tensor with shape [batch_dim, length_dim].
          - variable_dtype: an mtf.VariableDType
      variable_filter: a str, a variable will only be trained if its name
        matches this regex. If None (default), train all trainable variables.
      ensemble_inputs: an integer, see `train_model` docstring for details.
      iterations_per_loop: integer, steps per train loop.
    """

        mesh_shape = utils.tpu_mesh_shape(tpu_topology, model_parallelism)

        sequence_length = sequence_length or {"inputs": 512, "targets": 512}

        if isinstance(sequence_length, int):
            sequence_length = {
                "inputs": sequence_length,
                "targets": sequence_length
            }

        if not isinstance(batch_size, int):
            self._batch_size = utils.compute_batch_size(
                sequence_length, mesh_shape, layout_rules, batch_size)
        else:
            self._batch_size = batch_size

        self._learning_rate_schedule = (
            learning_rate_schedule
            or learning_rate_schedules.learning_rate_schedule_noam)

        self._optimizer = optimizer or optimize.AdafactorOptimizer

        self._sequence_length = sequence_length
        self._model_dir = model_dir
        self._model_type = model_type

        self._layout_rules = mtf.convert_to_layout_rules(layout_rules)
        self._mesh_shape = mtf.convert_to_shape(mesh_shape)

        self._autostack = autostack
        self._keep_checkpoint_max = keep_checkpoint_max
        self._save_checkpoints_steps = save_checkpoints_steps
        self._predict_fn = predict_fn
        self._variable_filter = variable_filter
        self._ensemble_inputs = ensemble_inputs
        self._iterations_per_loop = iterations_per_loop

        self._cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu or "", zone=tpu_zone, project=gcp_project)
        self._tpu = tpu
        self._tpu_job_name = tpu_job_name
        self._estimator = None
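
For context, a hypothetical instantiation of the constructor above; every argument value is illustrative, including the bucket path and TPU name:

model = MtfModel(
    model_dir="gs://my-bucket/model",           # hypothetical path
    tpu="my-tpu",                               # hypothetical TPU name
    tpu_topology="2x2",
    model_parallelism=8,
    batch_size=("tokens_per_batch", 65536),
    sequence_length={"inputs": 512, "targets": 128},
)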
Example #24
File: toy.py  Project: NeuroArchitect/lm
    def __call__(self, features, labels, mode, params):  # this is the model_fn
        """A model is called by TpuEstimator."""
        del labels
        global_step = tf.train.get_global_step()

        # Graph setup
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(self.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(self.layout)
        if params["use_tpu"]:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # Worker 0 caches all the TPU binaries.
            replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
            worker0_mem = replica_cache_size * 8 * num_hosts
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh = mtf.Mesh(graph, "my_mesh", var_placer)
            mesh_devices = [""] * mesh_shape.size

            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

            mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        # RUN Model
        with mtf.utils.outside_all_rewrites():
            logits, loss = self.model(mesh, features, params)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            if self.optimizer == "Adafactor":
                optimizer = mtf.optimize.AdafactorOptimizer()
            else:
                assert self.optimizer == "SGD"
                optimizer = mtf.optimize.SgdOptimizer(
                    learning_rate=self.learning_rate)
                update_ops = optimizer.apply_grads(var_grads,
                                                   graph.trainable_variables)
        else:
            # for now, we can only export fully-replicated tensors.
            fully_replicated_logits = mtf.anonymize(logits)

        # convert back to TensorFlow format
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        else:
            tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

        # create estimator
        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(
                    tf.global_variables(),
                    sharded=True,
                    max_to_keep=10,
                    keep_checkpoint_every_n_hours=2,
                    defer_build=False,
                    save_relative_paths=True,
                )
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    self.model_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener],
                )

                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                )
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(tf_logits):
                    mean_logits = tf.metrics.mean(tf_logits)
                    return {"mean_logits": mean_logits}

                eval_metrics = (metric_fn, [tf_logits])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                )
            elif mode == tf.estimator.ModeKeys.PREDICT:
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.PREDICT,
                    prediction_hooks=[restore_hook],
                    # illustrative output key for the exported logits
                    predictions={"logits": tf_logits},
                )

    @property
    def dense_initializer(self):
        if self.config.initializer_range:
            return tf.truncated_normal_initializer(
                stddev=self.config.initializer_range)
        else:
            return mtf.layers.VarianceScalingInitializer(scale=0.4)

    @property
    def embedding_initializer(self):
        initializer = self.dense_initializer
        if isinstance(initializer, mtf.layers.DenseInitializer):
            # The embedding matrix is also used as the classifier weight
            # matrix; scale it appropriately.
            return initializer(reduced_dims=[self.model_dim],
                               new_dims=[self.vocab_dim])
        else:
            return initializer

    @property
    def num_hidden_layers(self):
        return self.config.num_hidden_layers

    def normalize(self, x, reduce_dim):
        return nn.layer_norm(
            x,
            reduce_dim,
            subtract_mean=self.config.use_bias,
            use_bias=self.config.use_bias,
        )

    def model(self, mesh, x, y, params):
        # x :: [batch, io, vocab]

        if params["precision"] == "bfloat16":
            dtype = tf.bfloat16
            # master has type float32; slices and activations have bfloat16
            variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                               tf.bfloat16)
        else:
            dtype = tf.float32
            # master, slices and activations all have type float32
            variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                               tf.float32)

        # Build the actual model.
        batch_dim = mtf.Dimension("batch", params["batch_size"])
        vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
        io_dim = mtf.Dimension("sequence", params["io"])
        io_chan_dim = mtf.Dimension("io", params["io_channels"])

        # From tf input to mtf.
        x = mtf.import_tf_tensor(mesh, x,
                                 mtf.Shape([batch_dim, io_dim, vocab_dim]))

        # Embeddings
        with tf.variable_scope("toy", default_name="seq2seq"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([vocab_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )

                word_embedding_output = mtf.gather(
                    embedding_table,
                    x,
                    dim=vocab_dim,
                    output_shape=io_chan_dim)

                # Add positional embeddings and token type embeddings, then
                # layer normalize and perform dropout.
                embedding_output = word_embedding_output

                pos_embedding = mtf.get_variable(
                    mesh,
                    "pos_embeddings",
                    mtf.Shape([io_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )
                embedding_output = self.normalize(embedding_output,
                                                  io_chan_dim)
                embedding_output = mtf.dropout(
                    embedding_output,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob,
                )

            # Shift tokens by the positional embeddings.
            x = word_embedding_output + pos_embedding
            x = mtf.cast(x, variable_dtype.activation_dtype)

            h = x
            for lnum in range(1, self.num_hidden_layers + 2):
                if lnum + 1 == self.num_hidden_layers + 2:
                    # output layer
                    dim = io_dim
                elif lnum % 2 == 0:
                    dim = mtf.Dimension("hidden_even", io_chan_dim.size)
                else:
                    dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
                # The dense layer must apply on every iteration, not only in
                # the odd branch, so it sits outside the if/elif/else.
                h = mtf.layers.dense(
                    h,
                    dim,
                    use_bias=False,
                    master_dtype=variable_dtype.master_dtype,
                    slice_dtype=variable_dtype.slice_dtype,
                    name="layer_%d" % lnum,
                )

            prediction = h
            # project back to token dimensions

            # Compute the mean squared loss between the target and the output.
            loss = mtf.reduce_mean(mtf.square(y - prediction))
            return prediction, loss
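
The precision switch in model() above boils down to mtf.VariableDType, which separates the checkpointed master dtype from the per-device slice and activation dtypes. A minimal sketch of the bfloat16 configuration used above:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

variable_dtype = mtf.VariableDType(
    master_dtype=tf.float32,       # checkpointed master copy stays float32
    slice_dtype=tf.bfloat16,       # per-device variable slices
    activation_dtype=tf.bfloat16)  # activations flowing through the model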