示例#1
0
def eq_cifar_fn(x, output_dim=10, trainable=True):
    """Build a three-layer group-convolution network with a dense ReLU head.

    The layers are chained Z2 -> C4 -> C4 -> C4 via ``gconv2d_util`` /
    ``gconv2d``. The first two convolutions use stride 2, the last stride 1.
    The final feature map is reshaped to expose an axis of size 4, averaged
    over that axis, max-pooled, flattened and fed to a dense layer.

    Fixes relative to the previous version:
      * ``out_channels`` was read in the final reshape but never defined
        (NameError); it is now an explicit local.
      * ``in_channels`` of layers 2 and 3 were hard-coded to 16 and 8, which
        disagreed with the 8 and 32 output channels of the layers feeding
        them. Each layer's ``in_channels`` is now tied to the previous layer's
        ``out_channels``.
      * The reshape used the static batch dimension, which is ``None`` for
        inputs with an unknown batch size; ``-1`` is used instead.
      * Two max-pooling ops whose results were never consumed were removed.
        The convolutions have always been chained directly
        (conv1 -> conv2 -> conv3).

    Args:
      x: input image tensor; presumably NHWC with 3 channels, since the first
        layer declares ``in_channels=3`` -- confirm against callers.
      output_dim: number of units of the final dense layer.
      trainable: forwarded to the final dense layer only. The group-conv
        filters are created with ``tf.get_variable`` defaults.
        NOTE(review): confirm whether the filters should honour this flag.

    Returns:
      The ReLU-activated output of the final dense layer.
    """
    conv1_channels = 8
    conv2_channels = 32
    out_channels = 2

    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='Z2',
        h_output='C4',
        in_channels=3,
        out_channels=conv1_channels,
        ksize=3)
    w = tf.get_variable('w1', shape=w_shape)
    conv1 = gconv2d(input=x,
                    filter=w,
                    strides=[1, 2, 2, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)

    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4',
        h_output='C4',
        in_channels=conv1_channels,
        out_channels=conv2_channels,
        ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1,
                    filter=w,
                    strides=[1, 2, 2, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)

    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4',
        h_output='C4',
        in_channels=conv2_channels,
        out_channels=out_channels,
        ksize=5)
    w = tf.get_variable('w3', shape=w_shape)
    conv3 = gconv2d(input=conv2,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)

    # Expose an axis of size 4 (presumably the C4 rotations -- confirm against
    # gconv2d's channel layout) and average over it. -1 keeps the reshape
    # valid when the batch size is not statically known.
    spatial_dims = conv3.get_shape().as_list()[1:3]
    conv3 = tf.reshape(conv3, [-1] + spatial_dims + [4, out_channels])
    conv3 = tf.reduce_mean(conv3, axis=3)

    pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)
    pool3_flat = tf.layers.flatten(pool3)
    u = tf.layers.dense(inputs=pool3_flat,
                        units=output_dim,
                        activation=tf.nn.relu,
                        trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    return u
示例#2
0
def eq_cnn_fn(x, output_dim=10, trainable=True, group='C4', num_filters=2):
    """Build a two-layer group-convolution network with a dense ReLU head.

    A Z2 -> C4 convolution is followed by a C4 -> C4 convolution, both 5x5
    with stride 1. The second feature map is reshaped to expose an axis of
    size 4, averaged over that axis, max-pooled, flattened and fed to a dense
    layer.

    Changes relative to the previous version:
      * The reshape used the static batch dimension, which is ``None`` for
        inputs with an unknown batch size and made ``tf.reshape`` fail;
        ``-1`` is used instead.
      * A second reshape to the tensor's own shape (a no-op) was dropped.
      * A leftover debug ``print`` of the flattened shape was removed.
      * A max-pooling op on ``conv1`` whose result was never consumed was
        removed. ``conv2`` has always read ``conv1`` directly.

    Args:
      x: input tensor; its fourth dimension is used as the number of input
        channels, so presumably NHWC -- confirm against callers.
      output_dim: number of units of the final dense layer.
      trainable: forwarded to the final dense layer only. The group-conv
        filters are created with ``tf.get_variable`` defaults.
      group: accepted for signature compatibility; not used (the group is
        fixed to 'C4' below).
      num_filters: accepted for signature compatibility; not used (both
        layers produce 2 output channels).

    Returns:
      The ReLU-activated output of the final dense layer.
    """
    nchannels = x.shape[3]
    out_channels = 2

    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='Z2',
        h_output='C4',
        in_channels=nchannels,
        out_channels=out_channels,
        ksize=5)
    w = tf.get_variable('w1', shape=w_shape)
    conv1 = gconv2d(input=x,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)

    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4',
        h_output='C4',
        in_channels=out_channels,
        out_channels=out_channels,
        ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)

    # Expose an axis of size 4 (presumably the C4 rotations -- confirm against
    # gconv2d's channel layout) and average over it. -1 keeps the reshape
    # valid when the batch size is not statically known.
    spatial_dims = conv2.get_shape().as_list()[1:3]
    conv2 = tf.reshape(conv2, [-1] + spatial_dims + [4, out_channels])
    conv2 = tf.reduce_mean(conv2, axis=3)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.layers.flatten(pool2)
    u = tf.layers.dense(inputs=pool2_flat,
                        units=output_dim,
                        activation=tf.nn.relu,
                        trainable=trainable)
    # The collection receives the rotation-averaged feature map.
    tf.add_to_collection('conv_output2', conv2)
    return u
示例#3
0
def cnn_fn(x,
           output_dim,
           trainable=True,
           group=None,
           mnist=True,
           num_filters=64):
    """Plain CNN: two conv/max-pool stages, a 60-unit ReLU layer, linear head.

    Adapted from https://www.tensorflow.org/tutorials/layers

    Args:
      x: input tensor passed straight to the first convolution.
      output_dim: number of units of the final (linear) dense layer.
      trainable: forwarded to every conv and dense layer.
      group: accepted for signature compatibility; not used.
      mnist: accepted for signature compatibility; has no effect on the
        graph that is built.
      num_filters: number of filters of both convolutions.

    Returns:
      The un-activated output of the final dense layer.
    """

    def _conv_then_pool(inputs, collection_name):
        # 5x5 "same" convolution without activation; the raw conv output is
        # recorded in the given collection before 2x2 max pooling.
        conv = tf.layers.conv2d(inputs=inputs,
                                filters=num_filters,
                                kernel_size=[5, 5],
                                padding="same",
                                activation=None,
                                trainable=trainable)
        tf.add_to_collection(collection_name, conv)
        return tf.layers.max_pooling2d(inputs=conv,
                                       pool_size=[2, 2],
                                       strides=2)

    net = x
    for collection_name in ('conv_output1', 'conv_output2'):
        net = _conv_then_pool(net, collection_name)

    flat = tf.layers.flatten(net)
    hidden = tf.layers.dense(inputs=flat,
                             units=60,
                             activation=tf.nn.relu,
                             trainable=trainable)
    return tf.layers.dense(inputs=hidden,
                           units=output_dim,
                           activation=None,
                           trainable=trainable)
示例#4
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

        Builds a Mesh TensorFlow graph around ``transformer_model`` (taken
        from the enclosing scope), lowers it to TensorFlow and returns the
        estimator spec for the requested mode. Besides its arguments it reads
        these names from the enclosing scope: ``use_tpu``, ``mesh_shape``,
        ``layout_rules``, ``batch_size``, ``outer_batch_size``,
        ``sequence_length``, ``model_type``, ``autostack``, ``metric_names``,
        ``learning_rate``, ``checkpoints_to_keep``, ``model_dir`` and
        ``save_steps``.

        Args:
          features: input features dictionary
          labels: ignored
          mode: a tf.estimator.ModeKeys
          params: dictionary; ``params["context"]`` is read when ``use_tpu``
            is set
          config: ignored

        Returns:
          A ``TPUEstimatorSpec`` for PREDICT, for EVAL, and for TRAIN on TPU;
          a ``tf.estimator.EstimatorSpec`` for TRAIN without TPU.

        Raises:
          ValueError: if a required feature is missing from ``features`` or
            ``transformer_model`` is neither a Unitransformer nor a
            Bitransformer.
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

            Args:
              key: a string
              allow_missing: a boolean; when True a missing key yields None
                instead of raising

            Returns:
              a mtf.Tensor built from an int32 tensor, with shape
              [outer_batch_dim, batch_dim, length_dim], or None if the key is
              missing and allow_missing is True

            Raises:
              ValueError: if the key is missing and allow_missing is False
            """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            # Merge the outer-batch and batch dimensions before decoding.
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            # Imported features are 3-D (outer_batch, batch, length); the
            # previous two-name unpack raised "too many values to unpack".
            _, _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # metric_names becomes locally scoped if we simply assign
            # ["token_accuracy"] to it conditioned on if it's None, so the
            # default goes into a separate local name.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        # Only TRAIN reaches this point; PREDICT and EVAL returned above.
        # The optional segmentation/position features are None when absent.
        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            # The global step is advanced together with the variable updates.
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
示例#5
0
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
  """Batch normalization.

  Intended usage: for support images leave `moments` as None so the statistics
  are computed from that batch; for query images pass the `moments` returned
  by the support-set call so the same mean and var are reused.

  Args:
    x: inputs.
    params: None, or a dict holding the offset and scale values keyed by
      '<scope>/offset' and '<scope>/scale'. When None, the two are created as
      variables.
    moments: None, or a dict holding the mean and var to normalize with, keyed
      by '<scope>/mean' and '<scope>/var'.
    backprop_through_moments: when False, gradients are stopped at the given
      moments. Only relevant when `moments` is provided.
    use_ema: whether to keep exponential moving averages of the batch
      statistics. They are updated when `is_training` is True and applied when
      it is False. Passing `moments` takes precedence: in that case the
      averages are neither updated nor applied, so episodic learners do not
      update them a second time while processing queries.
    is_training: with use_ema=True, selects between updating the moving
      averages (True) and normalizing with them (False).
    ema_epsilon: weight kept on the previous moving average at each update;
      the new batch statistic is weighted by (1 - ema_epsilon).

  Returns:
    output: The result of applying batch normalization to the input.
    params: OrderedDict with the offset and scale that were used.
    moments: OrderedDict with the mean and var that were used.
  """
  with tf.variable_scope('batch_norm'):
    scope_name = tf.get_variable_scope().name
    mean_key = scope_name + '/mean'
    var_key = scope_name + '/var'
    offset_key = scope_name + '/offset'
    scale_key = scope_name + '/scale'

    if use_ema:
      num_channels = x.get_shape().as_list()[-1]
      ema_shape = [1, 1, 1, num_channels]
      mean_ema = tf.get_variable(
          'mean_ema',
          shape=ema_shape,
          initializer=tf.initializers.zeros(),
          trainable=False)
      var_ema = tf.get_variable(
          'var_ema',
          shape=ema_shape,
          initializer=tf.initializers.ones(),
          trainable=False)

    moments_given = moments is not None

    if moments_given:
      mean, var = moments[mean_key], moments[var_key]
      if not backprop_through_moments:
        # This variant does not yield good results.
        mean = tf.stop_gradient(mean)
        var = tf.stop_gradient(var)
    elif use_ema and not is_training:
      mean, var = mean_ema, var_ema
    else:
      # No moments supplied: derive mean and var from the current batch,
      # reducing over every axis except the last one.
      replica_ctx = tf.distribute.get_replica_context()
      reduce_axes = list(range(len(x.shape) - 1))
      if replica_ctx:
        # from third_party/tensorflow/python/keras/layers/normalization_v2.py
        local_sum = tf.reduce_sum(x, axis=reduce_axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(x), axis=reduce_axes, keepdims=True)
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        x_sum, x_squared_sum, global_batch_size = replica_ctx.all_reduce(
            'sum', [local_sum, local_squared_sum, batch_size])

        # Elements per channel = product of the non-batch reduced axes times
        # the batch size summed over replicas.
        axes_vals = [tf.shape(x)[axis] for axis in reduce_axes[1:]]
        multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
        multiplier = multiplier * global_batch_size

        mean = x_sum / multiplier
        x_squared_mean = x_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        var = x_squared_mean - tf.square(mean)
      else:
        mean, var = tf.nn.moments(x, axes=reduce_axes, keep_dims=True)

    # The moving averages are refreshed only while training and only when the
    # moments were computed in this very call.  Note: at test time for
    # episodic learners, ema's may be passed from the support set to the
    # query set, even if it's not really needed.
    if use_ema and is_training and not moments_given:
      replica_ctx = tf.distribute.get_replica_context()
      mean_upd = tf.assign(mean_ema,
                           mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
      var_upd = tf.assign(var_ema,
                          var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
      updates = tf.group([mean_upd, var_upd])
      if replica_ctx:
        # Under a replica context only replica 0 performs the update.
        is_first_replica = tf.equal(replica_ctx.replica_id_in_sync_group, 0)
        update_op = tf.cond(is_first_replica, lambda: updates, tf.no_op)
      else:
        update_op = updates
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

    if params is None:
      offset = tf.get_variable(
          'offset',
          shape=mean.get_shape().as_list(),
          initializer=tf.initializers.zeros())
      scale = tf.get_variable(
          'scale',
          shape=var.get_shape().as_list(),
          initializer=tf.initializers.ones())
    else:
      offset = params[offset_key]
      scale = params[scale_key]

    output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

    params = collections.OrderedDict([(offset_key, offset),
                                      (scale_key, scale)])
    moments = collections.OrderedDict([(mean_key, mean), (var_key, var)])

    return output, params, moments
示例#6
0
def cifar_fn(x, output_dim=10, trainable=True, group=None, mnist=True):
    """Plain CNN with three double-convolution stages and a dense ReLU head.

    Adapted from https://www.tensorflow.org/tutorials/layers

    The input is reshaped to [-1, 28, 28, 1] when ``mnist`` is truthy and to
    [-1, 32, 32, 3] otherwise. Stages use 16, 32 and 64 filters; each ends in
    2x2 max pooling, and the first is additionally followed by dropout (0.25).

    Args:
      x: input tensor, reshaped to the image shape selected by ``mnist``.
      output_dim: number of units of the final dense layer.
      trainable: forwarded to every conv layer and to the dense layer.
      group: accepted for signature compatibility; not used.
      mnist: selects the image shape used for the initial reshape.

    Returns:
      The ReLU-activated output of the final dense layer.
    """
    image_shape = [28, 28, 1] if mnist else [32, 32, 3]

    def _relu_conv(inputs, filters, ksize):
        # Square "same"-padded convolution with ReLU activation.
        return tf.layers.conv2d(inputs=inputs,
                                filters=filters,
                                kernel_size=[ksize, ksize],
                                padding="same",
                                activation=tf.nn.relu,
                                trainable=trainable)

    def _halve(inputs):
        # 2x2 max pooling with stride 2.
        return tf.layers.max_pooling2d(inputs=inputs,
                                       pool_size=[2, 2],
                                       strides=2)

    # Stage 1: two 3x3 convs with 16 filters; the first one is recorded.
    stage1 = _relu_conv(tf.reshape(x, [-1] + image_shape), 16, 3)
    tf.add_to_collection('conv_output1', stage1)
    stage1 = _relu_conv(stage1, 16, 3)
    stage1 = layers.Dropout(0.25)(_halve(stage1))

    # Stage 2: two 3x3 convs with 32 filters; the first one is recorded.
    stage2 = _relu_conv(stage1, 32, 3)
    tf.add_to_collection('conv_output2', stage2)
    stage2 = _halve(_relu_conv(stage2, 32, 3))

    # Stage 3: a 5x5 then a 7x7 conv with 64 filters; the second is recorded.
    stage3 = _relu_conv(_relu_conv(stage2, 64, 5), 64, 7)
    tf.add_to_collection('conv_output3', stage3)

    flat = tf.layers.flatten(_halve(stage3))
    return tf.layers.dense(inputs=flat,
                           units=output_dim,
                           activation=tf.nn.relu,
                           trainable=trainable)
示例#7
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: something
      config: something

    Returns:
      something
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memeory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(features[key])
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """Compute the training loss for one set of features.

                This is the kind of function mtf.serialize_training_step
                needs; it is also called directly when only a single
                microbatch is used.

                Args:
                  mtf_features: a dictionary. "targets" is always read.
                    "inputs" is read unless model_type is "lm". The
                    "*_segmentation" and "*_position" entries are optional
                    and are passed on as None when absent.

                Returns:
                  a dictionary with the single key "loss". The loss is
                  divided by num_microbatches when num_microbatches > 1.

                Raises:
                  ValueError: if transformer_model is neither a
                    Unitransformer nor a Bitransformer and model_type is
                    not "bi_student_teacher".
                """

                def optional_feature(key):
                    # Optional features fall back to None when missing.
                    return mtf_features.get(key, None)

                targets = mtf_features["targets"]
                if model_type != "lm":
                    inputs = mtf_features["inputs"]
                else:
                    # Language-model case: there is no separate "inputs"
                    # feature; the inputs are the targets shifted by one
                    # position along the last dimension, without wrap-around.
                    _, _, seq_dim = targets.shape
                    inputs = mtf.shift(
                        targets, offset=1, dim=seq_dim, wrap=False)

                # The keyword names for segment ids / positions depend on
                # which kind of transformer is being trained.
                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = {
                        "sequence_id": optional_feature(
                            "targets_segmentation"),
                        "position": optional_feature("targets_position"),
                    }
                elif (isinstance(transformer_model, transformer.Bitransformer)
                      or model_type == "bi_student_teacher"):
                    position_kwargs = {
                        "encoder_sequence_id": optional_feature(
                            "inputs_segmentation"),
                        "decoder_sequence_id": optional_feature(
                            "targets_segmentation"),
                        "encoder_position": optional_feature(
                            "inputs_position"),
                        "decoder_position": optional_feature(
                            "targets_position"),
                    }
                else:
                    raise ValueError("unrecognized class")

                # Only the loss is needed for training; the logits are
                # discarded immediately.
                _, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)

                if num_microbatches > 1:
                    # Scale each microbatch's loss by the microbatch count.
                    loss /= float(num_microbatches)
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])