Example #1
  def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    mtf_samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    if self.has_input:
      ndims = len(outputs.shape.as_list())
      actual_batch_size = tf.shape(features["inputs"])[0]
      outputs = tf.slice(
          outputs, [0] * ndims, [actual_batch_size] + [-1] * (ndims - 1))
    predictions = {
        "outputs": outputs
    }
    if features.get("infer_targets") is not None:
      predictions["infer_targets"] = features["infer_targets"]

    if features.get("inputs") is not None:
      predictions["inputs"] = features["inputs"]

    if use_tpu:
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
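
The tf.slice above trims batch padding: the decoded tensor's batch dimension can be larger than the number of real examples (for instance when the batch is padded to satisfy TPU shape constraints), so only the first actual_batch_size rows are kept. A minimal standalone sketch of the same trick (TF 1.x; the sizes 8 and 5 are hypothetical):

import tensorflow as tf

padded_outputs = tf.zeros([8, 16], dtype=tf.int32)  # stand-in for `outputs`
real_inputs = tf.zeros([5, 16], dtype=tf.int32)     # stand-in for features["inputs"]

ndims = len(padded_outputs.shape.as_list())
actual_batch_size = tf.shape(real_inputs)[0]
# Keep the first `actual_batch_size` rows; -1 keeps every other dim whole.
trimmed = tf.slice(padded_outputs, [0] * ndims,
                   [actual_batch_size] + [-1] * (ndims - 1))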
Example #2
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
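
For context, a model_fn like the one above is typically handed to TPUEstimator along these lines (a minimal sketch assuming the TF 1.x APIs used throughout these examples; train_input_fn, FLAGS.batch_size, and FLAGS.train_steps are placeholders, not part of the example):

import tensorflow as tf
from tensorflow.python.tpu import tpu_config, tpu_estimator

tpu_cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
run_config = tpu_config.RunConfig(
    cluster=tpu_cluster,
    model_dir=FLAGS.model_dir,
    tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
estimator = tpu_estimator.TPUEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.batch_size,
    eval_batch_size=FLAGS.batch_size,
    config=run_config)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)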
Example #3
def model_fn(features, labels, mode, params):
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape,
                                                layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32,
                                           slice_dtype=tf.float32,
                                           activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {
        "inputs": params["n_ctx"],
        "labels": params["n_ctx"]
    }

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if isinstance(features_dict[key], dict):
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(
        mesh, length_dim, memory_length_dim,
        variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs,
                other_features=other_features,
                params=params,
                variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"],
                stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'])

        else:
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Get the number of microbatches per batch for serialized training
        # (if the tokens_per_mb_per_replica param is None, this defaults to 1
        # and no microbatching is performed)
        num_microbatches = int(
            mtf_transformer.utils.serialize_num_microbatches(
                batch_dim=batch_dim,
                sequence_length=sequence_length_dict,
                mesh_shape=mesh_shape,
                layout_rules=layout_rules,
                tokens_per_microbatch_per_replica=params[
                    "tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num_microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype)
                return {
                    "logits": logits,
                    "loss": loss,
                    "loss_batch": loss_batch
                }
            else:
                raise Exception(
                    f"'{params['model']}' is not a valid model - please select from [GPT]"
                )

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(
            mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(
                        mtf_features,
                        other_features,
                        params,
                        mesh,
                        variable_dtype=variable_dtype,
                        context=None)
        else:
            raise Exception(
                f"'{params['model']}' is not a valid model - please select from [GPT]"
            )

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(
                mesh,
                loss,
                params,
                variable_dtype=variable_dtype,
                inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(
                mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(
            global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(
            fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(
            fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(
            lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {
                    "mean_logits": mean_logits,
                    "perplexity": perp,
                    "bits per byte": bpb
                }

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                eos_token = params["eos_id"]
                answer_positions = tf.where(
                    tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(
                    tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers,
                                                   tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {
                    "lambada_acc": accuracy,
                    "lambada_log_ppl": log_perplexity
                }

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn,
                                [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
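
A note on the bits-per-byte metric above: the 0.29335 factor presumably encodes the tokens-per-byte ratio of the eval corpus, so the conversion is bits/byte = (nats/token) * (tokens/byte) / ln 2. A standalone sanity check of the arithmetic (plain Python; the loss value is hypothetical):

import math

loss_nats_per_token = 2.0    # hypothetical per-token loss in nats
tokens_per_byte = 0.29335    # assumed corpus ratio baked into _bits_per_byte
bits_per_byte = loss_nats_per_token * tokens_per_byte / math.log(2)
perplexity = math.exp(loss_nats_per_token)  # mirrors _perplexity above
print(bits_per_byte, perplexity)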
Example #4
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.
        Args:
          features: dictionary where keys are strings like "inputs" and "targets"
            and the values are the actual values of "inputs". See TPUEstimator's
            docs for more information
          labels: ignored argument
          mode: a tf.estimator.ModeKeys
          params: dictionary containing the key "context"
          config: ignored argument
        Returns:
          a TPUEstimatorSpec
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu and "context" in params:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
                mesh_shape.to_integer_list, physical_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        mtf_features = {}
        for key, x in features.items():
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            # Some auxiliary features may have been generated in packing.
            # The names of these new features are of the form
            #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
            #   We look up the lengths based on the original feature name, without
            #   the "_<suffix>".
            feature_length = sequence_length[key.split("_")[0]]
            length_dim = mtf.Dimension("length", feature_length)
            ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                             if ensemble_inputs else [])
            feature_shape = mtf.Shape(ensemble_dims +
                                      [outer_batch_dim, batch_dim, length_dim])
            x = tf.cast(features[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            if not use_tpu:
                tf.logging.info("feature %s : %s" % (key, x))
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=10)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)
            if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
                anon_targets = mtf.anonymize(mtf_features[key])

        if mode == tf.estimator.ModeKeys.PREDICT:

            def _feature_shape(key):
                feature_length = sequence_length[key.split("_")[0]]
                return mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", feature_length)
                ])

            mtf_features = {
                k: mtf.reshape(v, _feature_shape(k))
                for k, v in six.iteritems(mtf_features)
            }
            inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if has_partial_sequences:
                controlcodes = mtf_features["controlcode"]
            else:
                controlcodes = None

            if predict_fn:
                mtf_samples = predict_fn(model=transformer_model,
                                         features=mtf_features,
                                         variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Unitransformer):
                # pad so that there is enough room for the targets
                inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                                 length_dim.name)
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=get_variable_dtype(),
                    remove_partial_sequences=True)
            elif isinstance(transformer_model, Bitransformer_ll):
                mtf_samples = transformer_model.decode(
                    inputs,
                    attributes=attributes,
                    controlcodes=controlcodes,
                    has_partial_sequences=has_partial_sequences,
                    remove_partial_sequences=remove_partial_sequences,
                    variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            inputs = mtf.anonymize(inputs)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            inputs = lowering.export_to_tf_tensor(inputs)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"inputs": inputs, "outputs": outputs}

            # When exporting a model, we need to communicate to TF-Serving that
            # master variables need to be copied to their slave slice variables.
            # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
            # augment the default "local_init_op" here.
            #
            # The "ready_op" is also constructed here to ensure the variables
            # initialized by "local_init_op" are the same ones checked by "ready_op".
            #
            # WARNING: Any variables created outside of this model_fn()
            # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
            # checked by these ops.
            def scaffold_fn():
                return tf.train.Scaffold(
                    local_init_op=tf.group(
                        tf.train.Scaffold.default_local_init_op(),
                        lowering.copy_masters_to_slices(),
                        name="mtf_local_init_op"),
                    ready_op=tf.concat([
                        tf.report_uninitialized_variables(),
                        resources.report_uninitialized_resources()
                    ],
                                       axis=0,
                                       name="mtf_ready_op"))

            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        assert (mode == tf.estimator.ModeKeys.TRAIN
                or mode == tf.estimator.ModeKeys.EVAL)

        def logits_and_loss(mtf_features):
            """Compute logits and loss.
            Args:
              mtf_features: a dictionary
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor
            """
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_microbatches = serialize_num_microbatches(
                batch_dim, sequence_length, mesh_shape, layout_rules)
            if num_microbatches > 1:

                def serialized_fn(mtf_features):
                    return {
                        "loss":
                        (logits_and_loss(mtf_features)[1] / num_microbatches)
                    }

                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, serialized_fn, batch_dim, num_microbatches)
                loss = loss_dict["loss"]
            else:
                loss = logits_and_loss(mtf_features)[1]
                var_grads = mtf.gradients(
                    [loss], [v.outputs[0] for v in graph.trainable_variables])

            if tpu_summaries:
                mtf.scalar_summary("loss", loss)

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            if isinstance(variable_filter, str):
                pattern = re.compile(variable_filter)
                variable_filter_fn = lambda v: pattern.search(v.name)
            elif variable_filter is None:
                variable_filter_fn = lambda v: True
            elif callable(variable_filter):
                variable_filter_fn = variable_filter
            else:
                raise ValueError(
                    "variable_filter must be None, a string, or a callable function"
                )
            trainable_vars = [
                v for v in graph.trainable_variables if variable_filter_fn(v)
            ]
            trainable_var_grads = [
                g for g, v in zip(var_grads, graph.trainable_variables)
                if variable_filter_fn(v)
            ]
            if len(trainable_vars) != len(graph.trainable_variables):
                tf.logging.info("Variables being trained:")
                tf.logging.info([v.name for v in trainable_vars])
                tf.logging.info("Variables not being trained:")
                tf.logging.info([
                    v.name for v in graph.trainable_variables
                    if not variable_filter_fn(v)
                ])

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                trainable_var_grads, trainable_vars)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.cast(tf_loss, tf.float32)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            if tpu_summaries:
                # has to be outside of
                # with mtf.utils.outside_all_rewrites()
                host_call = mtf.utils.create_host_call(model_dir)
                mtf.utils.remove_summaries()
            else:
                host_call = None

            with mtf.utils.outside_all_rewrites():

                if init_checkpoint:
                    ckpt_vars = {
                        v
                        for v, _ in tf.train.list_variables(init_checkpoint)
                    }
                    global_vars = {v.op.name for v in tf.global_variables()}
                    restore_vars = ckpt_vars.intersection(global_vars)
                    tf.logging.info("Initializing variables from %s:",
                                    init_checkpoint)
                    tf.logging.debug("\n".join(sorted(restore_vars)))
                    tf.logging.info("Variables in %s but not in graph:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                    tf.logging.info("Variables in graph but not in %s:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  {v: v
                                                   for v in restore_vars})

                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir,
                    summarize_config=True,
                    include_step_in_filename=False)

                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            logits, loss = logits_and_loss(mtf_features)
            anon_logits = mtf.anonymize(logits)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
            tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                                tf.float32)

            def simple_metrics(logits, labels):
                """Simple metrics for teacher-forced eval."""
                weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
                xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
                token_correct = tf.cast(tf.equal(predictions, labels),
                                        tf.float32) * weights
                sequence_correct = tf.to_float(
                    tf.equal(tf.reduce_sum(token_correct, -1),
                             tf.reduce_sum(weights, -1)))
                sequence_weights = tf.to_float(
                    tf.not_equal(tf.reduce_sum(weights, -1), 0))
                return {
                    "neg_log_perplexity":
                    tf.metrics.mean(-xent, weights),
                    "token_accuracy":
                    tf.metrics.mean(token_correct, weights),
                    "sequence_accuracy":
                    tf.metrics.mean(sequence_correct, sequence_weights)
                }

            labels = lowering.export_to_tf_tensor(anon_targets)
            eval_metrics = (simple_metrics, [tf_logits, labels])
            with mtf.utils.outside_all_rewrites():
                restore_hook = mtf.MtfRestoreHook(lowering)
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
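
Note that the TRAIN branch above accepts learning_rate_schedule either as a constant or as a callable of the global step. A minimal sketch of a compatible callable (linear warmup; the constants are hypothetical):

def linear_warmup_schedule(step, base_rate=1e-3, warmup_steps=10000.0):
    # `step` arrives as the global-step tensor via
    # learning_rate_schedule(step=tf.train.get_global_step()).
    step = tf.cast(step, tf.float32)
    return base_rate * tf.minimum(1.0, step / warmup_steps)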
Example #5
def my_model_fn(features,
                labels,
                mode,
                params=None,
                config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: a dictionary; when running on TPU this contains the key "context"
    config: ignored argument
  Returns:
    a TPUEstimatorSpec (or an EstimatorSpec when not running on TPU)
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations
  # See comments in the FLAGS
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh, "inputs_segmentation"),
      decoder_sequence_id=import_feature(
          features, mesh, "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position")
  )

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}
      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
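
import_feature is defined elsewhere in the original source. A hypothetical minimal version, consistent with its call sites above, returns None for absent keys so that optional packing features (e.g. "inputs_segmentation") degrade gracefully (FLAGS.batch_size is a placeholder; FLAGS.max_length appears in the example):

def import_feature(features, mesh, key):
  if key not in features:
    return None
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  length_dim = mtf.Dimension("length", FLAGS.max_length)
  x = tf.cast(features[key], tf.int32)
  x = tf.reshape(x, [batch_dim.size, length_dim.size])
  return mtf.import_fully_replicated(
      mesh, x, mtf.Shape([batch_dim, length_dim]), name=key)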
Example #6
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)
        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        num_labels_dim = mtf.Dimension("num_labels", num_labels)
        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_label_ids = mtf.import_tf_tensor(mesh, label_ids, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits,
         probabilities) = create_model(bert_config, is_training, mtf_input_ids,
                                       mtf_input_mask, mtf_segment_ids,
                                       mtf_label_ids, num_labels_dim,
                                       layout_rules, mesh_shape)
        total_loss = mtf.anonymize(total_loss)
        per_example_loss = mtf.anonymize(per_example_loss)
        logits = mtf.anonymize(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                max_optimized_variable_size=FLAGS.max_optimized_variable_size,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss,
                                       weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(per_example_loss), label_ids,
                lowering.export_to_tf_tensor(logits), is_real_example
            ])

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = bert_lib.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                    scaffold_fn=scaffold_fn)
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode,
                    prediction_hooks=[restore_hook],
                    predictions={
                        "probabilities":
                        lowering.export_to_tf_tensor(probabilities)
                    },
                    scaffold_fn=scaffold_fn)
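The model_fn above is meant to be handed to a TPUEstimator. A minimal wiring sketch, assuming TF 1.x; model_fn is the function above, while FLAGS.master, FLAGS.num_tpu_cores, the batch-size flags, train_input_fn, and num_train_steps are placeholders from the enclosing training script, not part of the example:

run_config = tf.estimator.tpu.RunConfig(
    master=FLAGS.master,  # placeholder: TPU master address
    model_dir=FLAGS.output_dir,
    save_checkpoints_steps=1000,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=1000,
        num_shards=FLAGS.num_tpu_cores))  # placeholder: TPU core count

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)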
Example #8
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None):
        hparams = copy.deepcopy(hparams)
        use_tpu = params and params.get("use_tpu", False)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if (data_parallelism is None
                    or len(data_parallelism.ps_devices) == 1):
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: this must happen before remove_summaries().
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
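Every example in this collection follows the same build → lower → export pattern: build mtf ops on an mtf.Graph, choose a mesh implementation, lower, then export regular tf tensors. A self-contained single-device sketch of that pattern (toy shapes, not taken from any example above):

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "demo_mesh")

# Build the mtf computation.
batch_dim = mtf.Dimension("batch", 8)
io_dim = mtf.Dimension("io", 4)
x = mtf.import_tf_tensor(mesh, tf.random_normal([8, 4]),
                         mtf.Shape([batch_dim, io_dim]))
y = mtf.reduce_mean(mtf.square(x))  # scalar mtf.Tensor

# Lower onto a trivial one-processor placement mesh (empty layout rules).
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"), mtf.convert_to_layout_rules(""), [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_y = lowering.export_to_tf_tensor(y)

with tf.Session() as sess:
    print(sess.run(tf_y))

With no mtf variables there is nothing for MtfRestoreHook to copy, which is why the examples above only attach that hook once variables exist.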
Example #9
    def _model_fn(input_fea, input_lab):
        """Creates a model, adds summaries, handles train/eval, and sets up hooks."""
        def _add_summary(lowering, train_or_eval, tf_loss, scalars,
                         global_step):
            """Add all summaries."""
            for k in scalars.keys():
                if not isinstance(scalars[k], tf.Tensor):
                    scalars[k] = tf.cast(
                        lowering.export_to_tf_tensor(scalars[k]), tf.float32)

            def _host_loss_summary(global_step, tf_loss, **scalars):
                """Add summary.scalar in host side."""
                gs = tf.cast(global_step, tf.int64)
                sum_loss = tf.contrib.summary.scalar(
                    '{}_loss'.format(train_or_eval), tf_loss, step=gs)
                sum_ops = [sum_loss.op]
                for description, tf_metric in six.iteritems(scalars):
                    sum_metric = tf.contrib.summary.scalar('{}_{}'.format(
                        train_or_eval, description),
                                                           tf_metric,
                                                           step=gs)
                    sum_ops.append(sum_metric)
                with tf.control_dependencies(sum_ops):
                    return tf.identity(tf_loss)

            # Cast the global step to tf.int32, since
            # outside_compilation does not support tf.int64.
            tf_loss = tpu.outside_compilation(_host_loss_summary,
                                              tf.cast(global_step, tf.int32),
                                              tf_loss, **scalars)

            return tf_loss

        global_step = tf.train.get_or_create_global_step()
        graph = mtf.Graph()

        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

        tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
            cpu_devices, devices_memory_usage))
        var_placer = mtf.utils.BalancedVariablePlacer(cpu_devices,
                                                      devices_memory_usage)

        mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = unet.get_layout()
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    None, d_assignment)

        with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
            preds, loss, scalars, bn_update_ops = (
                unet.unet_with_spatial_partition(mesh, train_or_eval,
                                                 input_fea, input_lab))

        if train_or_eval == 'train':
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

            lr = FLAGS.lr * tf.pow(
                FLAGS.lr_drop_rate,
                tf.floor(
                    tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
            scalars['learning_rate'] = lr

            optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf_update_ops.extend(
                [lowering.lowered_operation(op) for op in bn_update_ops])
            tf_update_ops_group = tf.group(tf_update_ops)

        else:  # train_or_eval == 'eval':
            preds = [mtf.anonymize(pred) for pred in preds]

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_preds = [
                tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
                for pred in preds
            ]

        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        if FLAGS.write_summary:
            tf_loss = _add_summary(lowering, train_or_eval, tf_loss, scalars,
                                   global_step)
        master_to_slice_hook = mtf.MtfRestoreHook(lowering)

        if train_or_eval == 'train':
            with mtf.utils.outside_all_rewrites():
                saver = tf.train.Saver(tf.global_variables(),
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                slice_to_master_hook = tf.train.CheckpointSaverHook(
                    FLAGS.checkpoint_dir,
                    save_steps=FLAGS.save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                captured_hooks.capture(
                    [master_to_slice_hook, slice_to_master_hook])
                return tf_update_ops_group

        else:  # train_or_eval == 'eval':
            tf_preds.extend([tf_loss, global_step])
            tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
            tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
            captured_hooks.capture([master_to_slice_hook, None])
            captured_output_dtypes_shapes.capture(
                [tf_preds_dtypes, tf_preds_shapes])
            return tpu_ops.outfeed_enqueue_tuple(tf_preds)
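The eval branch above ends with an outfeed_enqueue_tuple; the host side must dequeue with the captured dtypes and shapes. A dequeue-loop sketch, assuming the same tpu_ops module as the example; the host device string and num_cores_per_host are placeholders:

with tf.device('/job:worker/task:0/device:CPU:0'):  # placeholder host device
    dequeue_ops = []
    for ordinal in range(num_cores_per_host):
        dequeue_ops.append(
            tpu_ops.outfeed_dequeue_tuple(dtypes=tf_preds_dtypes,
                                          shapes=tf_preds_shapes,
                                          device_ordinal=ordinal))

Running an element of dequeue_ops yields one (preds..., tf_loss, global_step) tuple per enqueue on the corresponding core.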
Example #10
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    # MTF setup.
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info("device_list = %s" % device_list,)
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                mesh_devices,
                                                ctx.device_assignment)
    mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    batch_size = input_ids.get_shape()[0].value
    batch_dim = mtf.Dimension("batch", batch_size)
    seq_length = input_ids.get_shape()[1].value
    seq_dim = mtf.Dimension("seq", seq_length)

    mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
    mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                          [batch_dim, seq_dim])
    mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                           [batch_dim, seq_dim])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=mtf_input_ids,
        input_mask=mtf_input_mask,
        segment_ids=mtf_segment_ids)

    if mode == tf.estimator.ModeKeys.TRAIN:

      def compute_loss(logits, positions):
        one_hot_positions = mtf.one_hot(positions, output_dim=seq_dim)
        log_probs = mtf.log_softmax(logits, seq_dim)
        loss = -mtf.reduce_mean(
            mtf.reduce_sum(one_hot_positions * log_probs, reduced_dim=seq_dim))
        return loss

      start_positions = features["start_positions"]
      mtf_start_positions = mtf.import_tf_tensor(mesh, start_positions,
                                                 [batch_dim])
      end_positions = features["end_positions"]
      mtf_end_positions = mtf.import_tf_tensor(mesh, end_positions, [batch_dim])

      start_loss = compute_loss(start_logits, mtf_start_positions)
      end_loss = compute_loss(end_logits, mtf_end_positions)

      total_loss = (start_loss + end_loss) / 2.0
      _, update_ops = optimization_lib.create_optimizer(
          total_loss,
          learning_rate,
          num_train_steps,
          num_warmup_steps,
          max_optimized_variable_size=FLAGS.max_optimized_variable_size,
          optimizer=FLAGS.optimizer,
          clip_gradients=FLAGS.clip_gradients)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      start_logits = mtf.anonymize(start_logits)
      end_logits = mtf.anonymize(end_logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))
      global_step = tf.train.get_global_step()
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = bert_lib.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      if mode == tf.estimator.ModeKeys.TRAIN:
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.output_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook],
            scaffold_fn=scaffold_fn)
      elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "unique_ids": unique_ids,
            "start_logits": lowering.export_to_tf_tensor(start_logits),
            "end_logits": lowering.export_to_tf_tensor(end_logits),
        }

        return tf.estimator.tpu.TPUEstimatorSpec(
            mode,
            prediction_hooks=[restore_hook],
            predictions=predictions,
            scaffold_fn=scaffold_fn)
      else:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))
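Downstream, the PREDICT spec above is consumed through Estimator.predict, which yields one numpy dict per example. A consumption sketch, assuming a TPUEstimator built around this model_fn; predict_input_fn and all_results are placeholders:

all_results = []
for result in estimator.predict(input_fn=predict_input_fn,
                                yield_single_examples=True):
    all_results.append({
        "unique_id": int(result["unique_ids"]),
        "start_logits": [float(x) for x in result["start_logits"].flat],
        "end_logits": [float(x) for x in result["end_logits"].flat],
    })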
Example #11
File: utils.py  Project: appcoreopc/mesh
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: an optional parameter dict; on TPU it carries the
        TPUEstimator context under params["context"]
      config: an optional tf.estimator.RunConfig (unused here)

    Returns:
      a tpu_estimator.TPUEstimatorSpec or tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            batch_dim = mtf.Dimension("batch", batch_size)
            length_dim = mtf.Dimension("length", length)
            mtf_shape = mtf.Shape([batch_dim, length_dim])
            if key not in features:
                return None
            x = tf.to_int32(features[key])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            if text2self:
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=variable_dtype,
                    temperature=temperature)
            else:
                mtf_samples = transformer_model.decode(
                    inputs,
                    variable_dtype=variable_dtype,
                    beam_size=beam_size,
                    alpha=alpha,
                    temperature=temperature)
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if text2self:
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if text2self:
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation"),
                position=_import_feature("targets_position"),
            )
        else:
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation"),
                decoder_sequence_id=_import_feature("targets_segmentation"),
                encoder_position=_import_feature("inputs_position"),
                decoder_position=_import_feature("targets_position"),
            )

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=variable_dtype,
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer()
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[restore_hook, saver_hook])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:

                def padded_neg_log_perplexity(logits, labels):
                    weights = tf.to_float(tf.not_equal(labels, 0))
                    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=labels, logits=logits)
                    return {
                        "neg_log_perplexity": tf.metrics.mean(-xent, weights)
                    }

                labels = lowering.export_to_tf_tensor(anon_targets)
                eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
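In the text2self branch above, decoder inputs are simply the targets shifted right by one step along the length dimension. A standalone sketch of that trick with toy values on a single-device mesh (names and sizes are illustrative):

import mesh_tensorflow as mtf
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "shift_mesh")
batch_dim = mtf.Dimension("batch", 2)
length_dim = mtf.Dimension("length", 5)

targets = mtf.import_tf_tensor(
    mesh,
    tf.constant([[11, 12, 13, 14, 15],
                 [21, 22, 23, 24, 25]], dtype=tf.int32),
    mtf.Shape([batch_dim, length_dim]))
# wrap=False pads with zeros, so position 0 reads a zero "start" token.
inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"), mtf.convert_to_layout_rules(""), [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
with tf.Session() as sess:
    print(sess.run(lowering.export_to_tf_tensor(inputs)))
    # [[ 0 11 12 13 14]
    #  [ 0 21 22 23 24]]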
Example #12
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        # mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        # layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
        if FLAGS.mode == "auto_parallel":
            mesh_shape_map = {
                1: [("processor_rows", 1)],
                2: [("processor_rows", 2)],
                4: [("processor_rows", 2), ("processor_cols", 2)],
                8: [("processor_rows", 2), ("processor_cols", 4)]
            }
        elif FLAGS.mode == "data_parallel":
            mesh_shape_map = {
                1: [("processor_rows", 1)],
                2: [("processor_rows", 2)],
                4: [("processor_rows", 4)],
                8: [("processor_rows", 8)]
            }
        else:
            raise ValueError("unsupported FLAGS.mode: %s" % FLAGS.mode)

        mesh_shape = mesh_shape_map[FLAGS.gpu_num]
        devices = [f"gpu:{i}" for i in range(FLAGS.gpu_num)]

        var_placer = None
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        outputs = [total_loss]
        if FLAGS.mode == "auto_parallel":
            layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, outputs)
        elif FLAGS.mode == "data_parallel":
            layout_rules = [('batch', 'processor_rows')]
        else:
            raise ValueError("unsupported FLAGS.mode: %s" % FLAGS.mode)

        variables = graph.all_variables
        for v in variables:
            tf.logging.info(
                "[parameter] (name,shape,dtype): ({},{},{})".format(
                    v.name, v.shape, v.dtype.master_dtype))
        tf.logging.info("layout rules: {}".format(layout_rules))
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, devices)
        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                # saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                # saver_hook = tf.train.CheckpointSaverHook(
                #     FLAGS.output_dir,
                #     save_steps=1000,
                #     saver=saver,
                #     listeners=[saver_listener])

                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook])
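The auto_parallel branch above delegates layout selection to mtf.auto_mtf.layout, which searches for rules balancing per-processor memory and compute. A minimal standalone sketch (toy graph; the dense layer and mesh shape are illustrative):

import mesh_tensorflow as mtf
import mesh_tensorflow.auto_mtf  # registers the mtf.auto_mtf module
import tensorflow as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "auto_mesh")
batch_dim = mtf.Dimension("batch", 64)
io_dim = mtf.Dimension("io", 32)
hidden_dim = mtf.Dimension("hidden", 128)

x = mtf.import_tf_tensor(mesh, tf.random_normal([64, 32]),
                         mtf.Shape([batch_dim, io_dim]))
h = mtf.layers.dense(x, hidden_dim, activation=mtf.relu, name="h")
loss = mtf.reduce_mean(mtf.square(h))

mesh_shape = mtf.convert_to_shape("processor_rows:2;processor_cols:2")
layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [loss])
print(layout_rules)  # e.g. maps "batch"/"hidden" onto the mesh dimensions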
Example #13
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: an optional parameter dict; on TPU it carries the
        TPUEstimator context under params["context"]
      config: an optional tf.estimator.RunConfig (unused here)

    Returns:
      a tpu_estimator.TPUEstimatorSpec or tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(features[key])
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary
        Returns:
          a dictionary
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif isinstance(transformer_model, transformer.Bitransformer
                                ) or model_type == "bi_student_teacher":
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
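The branch above accepts learning_rate_schedule either as a constant or as a callable evaluated at the current global step. A sketch of a compatible callable; the warmup length and peak rate are placeholders:

def learning_rate_schedule(step):
    """Linear warmup to 1e-3 over the first 10k steps, then constant."""
    step = tf.cast(step, tf.float32)
    warmup_steps = 10000.0  # placeholder
    return 1e-3 * tf.minimum(step / warmup_steps, 1.0)

Because it is callable, the code above invokes it with step=tf.train.get_global_step() and logs the result as a summary outside all TPU rewrites.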
Example #14
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: this must happen before remove_summaries().
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
Example #15
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout), mesh_devices,
        params['context'].device_assignment)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
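toy_model is not shown in this example. A minimal mtf model consistent with its (features, mesh) -> (logits, loss) contract might look like this sketch; the dimension names, sizes, and flags are assumptions, not taken from the example:

def toy_model(features, mesh):
    # Assumed flags; the real example defines its own sizes.
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    io_dim = mtf.Dimension("io", FLAGS.io_size)
    hidden_dim = mtf.Dimension("hidden", FLAGS.hidden_size)

    x = mtf.import_tf_tensor(mesh, features,
                             mtf.Shape([batch_dim, io_dim]))
    h = mtf.layers.dense(x, hidden_dim, activation=mtf.relu, name="hidden")
    logits = mtf.layers.dense(h, io_dim, name="logits")
    loss = mtf.reduce_mean(mtf.square(logits - x))  # autoencoder-style loss
    return logits, loss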
Example #16
  def _model_fn(input_fea, input_lab):
    """Creates a model, adds summaries, handles train/eval, and sets up hooks."""

    # input_fea and input_lab should be a list (laid_out_tensors).
    if not isinstance(input_fea, list):
      input_fea = [input_fea]
    if not isinstance(input_lab, list):
      input_lab = [input_lab]

    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
      """Add all summaries."""
      for k in scalars.keys():
        if not isinstance(scalars[k], tf.Tensor):
          scalars[k] = tf.cast(
              lowering.export_to_tf_tensor(scalars[k]), tf.float32)

      def _host_loss_summary(global_step, tf_loss, **scalars):
        """Add summary.scalar in host side."""
        gs = tf.cast(global_step, tf.int64)
        sum_loss = contrib_summary.scalar(
            '{}_loss'.format(train_or_eval), tf_loss, step=gs)
        sum_ops = [sum_loss.op]
        for description, tf_metric in six.iteritems(scalars):
          sum_metric = contrib_summary.scalar(
              '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
          sum_ops.append(sum_metric)
        with tf.control_dependencies(sum_ops):
          return tf.identity(tf_loss)

      if FLAGS.use_tpu:
        # Cast the global step to tf.int32, since
        # outside_compilation does not support tf.int64.
        tf_loss = tpu.outside_compilation(
            _host_loss_summary,
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)
      else:
        tf_loss = _host_loss_summary(
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)

      return tf_loss

    global_step = tf.train.get_or_create_global_step()
    graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

    with mtf.utils.outside_all_rewrites():
      # Do not tpu_rewrite this part. Inside this unet, if you use TensorFlow
      # ops instead of Mesh TensorFlow ops, it will cause host-to-TPU send/recv.
      preds, loss, scalars, bn_update_ops = (
          unet.unet_with_spatial_partition(
              mesh, mesh_impl, train_or_eval, input_fea, input_lab))

    if train_or_eval == 'train':
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])

      lr = FLAGS.lr * tf.pow(
          FLAGS.lr_drop_rate,
          tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
      scalars['learning_rate'] = lr

      optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
      update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

      # This is where the actual TF graph gets built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf_update_ops.extend(
          [lowering.lowered_operation(op) for op in bn_update_ops])

    else:  # train_or_eval == 'eval':
      preds = [mtf.anonymize(pred) for pred in preds]

      # This is where the actual TF graph gets built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_preds = [tf.cast(
          lowering.export_to_tf_tensor(pred), tf.float32) for pred in preds]

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    if FLAGS.write_summary:
      tf_loss = _add_summary(
          lowering, train_or_eval, tf_loss, scalars, global_step)
    master_to_slice_hook = mtf.MtfRestoreHook(lowering)

    if train_or_eval == 'train':
      with mtf.utils.outside_all_rewrites():
        saver = tf.train.Saver(tf.global_variables(),
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        slice_to_master_hook = tf.train.CheckpointSaverHook(
            FLAGS.checkpoint_dir,
            save_steps=FLAGS.save_checkpoints_steps,
            saver=saver, listeners=[saver_listener])
        captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
        return tf.group([tf_loss] + tf_update_ops)

    else:  # train_or_eval == 'eval':
      if FLAGS.use_tpu:
        tf_preds.extend([tf_loss, global_step])
        tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
        tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
        captured_hooks.capture([master_to_slice_hook, None])
        captured_output_dtypes_shapes.capture(
            [tf_preds_dtypes, tf_preds_shapes])
        return tpu_ops.outfeed_enqueue_tuple(tf_preds)

      else:
        tf_preds.extend([tf_loss, global_step])
        captured_hooks.capture([master_to_slice_hook, None])
        return tf_preds
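
On the eval path above, the TPU side enqueues tf_preds with tpu_ops.outfeed_enqueue_tuple and captures their dtypes and shapes; a matching dequeue loop has to run on the host. A hedged sketch of what that consumer could look like, assuming the captured dtypes/shapes and a known batch count (none of this host-side code appears in the example itself):

from tensorflow.python.tpu.ops import tpu_ops  # same module as the enqueue side

def host_outfeed_dequeue(tf_preds_dtypes, tf_preds_shapes, num_batches):
    # Each dequeue must match exactly what outfeed_enqueue_tuple sent
    # from the TPU, in both dtypes and shapes.
    dequeued = []
    for _ in range(num_batches):
        dequeued.append(
            tpu_ops.outfeed_dequeue_tuple(dtypes=tf_preds_dtypes,
                                          shapes=tf_preds_shapes,
                                          device_ordinal=0))
    return dequeued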
Example #17
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

        Args:
          features: input features dictionary
          labels: ignored
          mode: a tf.estimator.ModeKeys
          params: a dict of parameters passed in by the Estimator
          config: a tf.estimator.RunConfig (unused)

        Returns:
          a TPUEstimatorSpec or tf.estimator.EstimatorSpec
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

            Args:
              key: a string
              allow_missing: a boolean

            Returns:
              an mtf.Tensor with dtype int32 and shape
              [outer_batch_dim, batch_dim, length_dim]
            """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # Assigning a default to metric_names directly would make it a
            # local variable and shadow the outer binding, so use a local name.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
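
A model_fn like my_model_fn above is not called directly; it is handed to a TPUEstimator together with a RunConfig. A minimal sketch of that wiring, with illustrative values for the directories and batch sizes, and an assumed my_input_fn:

import tensorflow.compat.v1 as tf

run_config = tf.estimator.tpu.RunConfig(
    model_dir="/tmp/my_model",  # illustrative path
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=my_model_fn,   # the function defined above
    config=run_config,
    train_batch_size=64,    # illustrative
    eval_batch_size=64,
    use_tpu=use_tpu)
estimator.train(input_fn=my_input_fn, max_steps=1000)  # my_input_fn assumed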
Example #18
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(ctx.device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   layout=layout_rules,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_logits,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_logits,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_logits = tf.reshape(masked_lm_logits,
                                              [-1, masked_lm_logits.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_logits = tf.reshape(
                    next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(masked_lm_example_loss),
                lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
                masked_lm_weights,
                lowering.export_to_tf_tensor(next_sentence_example_loss),
                lowering.export_to_tf_tensor(next_sentence_logits),
                next_sentence_labels
            ])

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
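
The BERT example above is built around a single round trip: tf.Tensors enter the mesh through mtf.import_tf_tensor with explicit named Dimensions, and results come back through lowering.export_to_tf_tensor. A tiny self-contained sketch of that round trip on a one-device placement mesh (all shapes and names here are illustrative):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "tiny_mesh")
batch_dim = mtf.Dimension("batch", 4)
seq_dim = mtf.Dimension("seq", 8)

x = tf.ones([4, 8], dtype=tf.int32)
mtf_x = mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, seq_dim]))
mtf_y = mtf.anonymize(mtf_x * 2)  # anonymize so the result can be exported

mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mtf.convert_to_shape("all:1"), mtf.convert_to_layout_rules(""), [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_y = lowering.export_to_tf_tensor(mtf_y)  # a plain [4, 8] tf.Tensor again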
Example #19
File: toy.py  Project: NeuroArchitect/lm
    def __call__(self, features, labels, mode, params):  # this is the model_fn
        """A model is called by TpuEstimator."""
        del labels
        global_step = tf.train.get_global_step()

        # Graph setup
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(self.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(self.layout)
        if params["use_tpu"]:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # Worker 0 caches all the TPU binaries.
            replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
            worker0_mem = replica_cache_size * 8 * num_hosts
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh = mtf.Mesh(graph, "my_mesh", var_placer)
            mesh_devices = [""] * mesh_shape.size

            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

            mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        # RUN Model
        with mtf.utils.outside_all_rewrites():
            logits, loss = self.model(mesh, features, params)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            if self.optimizer == "Adafactor":
                optimizer = mtf.optimize.AdafactorOptimizer()
            else:
                assert self.optimizer == "SGD"
                optimizer = mtf.optimize.SgdOptimizer(
                    learning_rate=self.learning_rate)
                update_ops = optimizer.apply_grads(var_grads,
                                                   graph.trainable_variables)
        else:
            # for now, we can only export fully-replicated tensors.
            fully_replicated_logits = mtf.anonymize(logits)

        # convert back to TensorFlow format
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        else:
            tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

        # create estimator
        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(
                    tf.global_variables(),
                    sharded=True,
                    max_to_keep=10,
                    keep_checkpoint_every_n_hours=2,
                    defer_build=False,
                    save_relative_paths=True,
                )
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    self.model_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener],
                )

                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                )
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(tf_logits):
                    mean_logits = tf.metrics.mean(tf_logits)
                    return {"mean_logits": mean_logits}

                eval_metrics = (metric_fn, [tf_logits])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                )
            elif mode == tf.estimator.ModeKeys.PREDICT:
                # tf_logits was exported on the non-TRAIN path above;
                # PREDICT specs take predictions and prediction hooks.
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.PREDICT,
                    predictions={"logits": tf_logits},
                    prediction_hooks=[restore_hook],
                )

    @property
    def dense_initializer(self):
        if self.config.initializer_range:
            return tf.truncated_normal_initializer(
                stddev=self.config.initializer_range)
        else:
            return mtf.layers.VarianceScalingInitializer(scale=0.4)

    @property
    def embedding_initializer(self):
        initializer = self.dense_initializer
        if isinstance(initializer, mtf.layers.DenseInitializer):
            # The embedding matrix is also used as the classifier weight
            # matrix, so scale it appropriately.
            return initializer(reduced_dims=[self.model_dim],
                               new_dims=[self.vocab_dim])
        else:
            return initializer

    @property
    def num_hidden_layers(self):
        return self.config.num_hidden_layers

    def normalize(self, x, reduce_dim):
        return nn.layer_norm(
            x,
            reduce_dim,
            subtract_mean=self.config.use_bias,
            use_bias=self.config.use_bias,
        )

    def model(self, mesh, x, y, params):
        # x :: [batch, io, vocab]

        if params["precision"] == "bfloat16":
            dtype = tf.bfloat16
            # Master has dtype float32; slices and activations are bfloat16.
            variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                               tf.bfloat16)
        else:
            dtype = tf.float32
            # Master, slices, and activations are all float32.
            variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                               tf.float32)

        # Build the actual model.
        batch_dim = mtf.Dimension("batch", params["batch_size"])
        vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
        io_dim = mtf.Dimension("sequence", params["io"])
        io_chan_dim = mtf.Dimension("io", params["io_channels"])

        # From the input tf.Tensor to an mtf.Tensor.
        x = mtf.import_tf_tensor(mesh, x,
                                 mtf.Shape([batch_dim, io_dim, vocab_dim]))

        # Embeddings
        with tf.variable_scope("toy", default_name="seq2seq"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([vocab_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )

                word_embedding_output = mtf.gather(
                    embedding_table,
                    x,
                    dim=vocab_dim,
                    output_shape=io_chan_dim)

                # Add positional embeddings, then layer-normalize and apply
                # dropout.
                embedding_output = word_embedding_output

                pos_embedding = mtf.get_variable(
                    mesh,
                    "pos_embeddings",
                    mtf.Shape([io_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )
                # normalize() requires the reduction dimension; layer norm
                # reduces over the channel dimension here.
                embedding_output = self.normalize(embedding_output, io_chan_dim)
                embedding_output = mtf.dropout(
                    embedding_output,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob,
                )

            # Shift tokens by the positional embeddings.
            x = word_embedding_output + pos_embedding
            x = mtf.cast(x, variable_dtype.activation_dtype)

            h = x
            for lnum in range(1, self.num_hidden_layers + 2):
                if lnum + 1 == self.num_hidden_layers + 2:
                    # output layer
                    dim = io_dim
                elif lnum % 2 == 0:
                    dim = mtf.Dimension("hidden_even", io_chan_dim.size)
                else:
                    dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
                # The dense layer applies on every iteration, not only in the
                # hidden_odd branch, so it sits at loop-body level.
                h = mtf.layers.dense(
                    h,
                    dim,
                    use_bias=False,
                    master_dtype=variable_dtype.master_dtype,
                    slice_dtype=variable_dtype.slice_dtype,
                    name="layer_%d" % lnum,
                )

            prediction = h
            # Project back to token dimensions.

            # Compute the mean square loss between the target and the output.
            loss = mtf.reduce_mean(mtf.square(y - prediction))
            return prediction, loss
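
For concreteness, a quick trace of the layer loop above under the assumption self.num_hidden_layers == 2, so the loop visits lnum = 1, 2, 3:

# lnum = 1 -> dim = hidden_odd  (size io_chan_dim.size), builds dense "layer_1"
# lnum = 2 -> dim = hidden_even (size io_chan_dim.size), builds dense "layer_2"
# lnum = 3 -> lnum + 1 == num_hidden_layers + 2, so dim = io_dim ("layer_3",
#             the output layer projecting back to the sequence dimension)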