Example #1
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict passed in by the Estimator; when running on TPU it
        contains the TPU context under the "context" key
      config: a RunConfig (unused)

    Returns:
      a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when not using TPU)
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string
        allow_missing: a boolean

      Returns:
        a mtf.Tensor with dtype int32 and shape
        [outer_batch_dim, batch_dim, length_dim]
      """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # Assigning a default to metric_names directly would make it a local
            # variable in this function, so use a separate local name instead.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
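
A minimal sketch of how a model function like this is typically handed to an estimator, assuming a TF 1.x-style setup; my_input_fn, model_dir, batch_size, train_steps, and use_tpu are hypothetical placeholders, not taken from the example above.

    # Hypothetical wiring, illustration only.
    import tensorflow.compat.v1 as tf

    run_config = tf.estimator.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=my_model_fn,
        config=run_config,
        use_tpu=use_tpu,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size)
    estimator.train(input_fn=my_input_fn, max_steps=train_steps)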
Example #2
    def __init__(self, train_learner, eval_learner, is_training, dataset_list,
                 checkpoint_dir, summary_dir, eval_finegrainedness,
                 eval_finegrainedness_split, eval_imbalance_dataset,
                 num_train_classes, num_test_classes, num_train_examples,
                 num_test_examples, learn_config, learner_config, data_config):
        """Initializes a Trainer.

    Args:
      train_learner: A Learner to be used for meta-training.
      eval_learner: A Learner to be used for meta-validation or meta-testing.
      is_training: Bool, whether to train or just evaluate.
      dataset_list: A list of names of datasets to include in the benchmark.
        This can be any subset of the supported datasets.
      checkpoint_dir: A string, the path to the checkpoint directory, or None if
        no checkpointing should occur.
      summary_dir: A string, the path to the summary directory, or None if no
        summaries should be saved.
      eval_finegrainedness: Whether to perform binary ImageNet evaluation for
        assessing the performance on fine- vs coarse-grained tasks.
      eval_finegrainedness_split: The subgraph of ImageNet to perform the
        aforementioned analysis on. Notably, if this is 'train', we need to
        ensure that the training data is used episodically, even if the given
        model is the baseline model, which usually uses batches for training.
      eval_imbalance_dataset: A dataset on which to perform evaluation for
        assessing how class imbalance affects performance in binary episodes. By
        default it is empty and no imbalance analysis is performed.
      num_train_classes: An int or None, the number of classes in episodic
        meta-training.
      num_test_classes: An int or None, the number of classes in episodic
        meta-testing.
      num_train_examples: An int or None, the number of support examples.
      num_test_examples: An int or None, the number of query examples.
      learn_config: A LearnConfig, the learning configuration.
      learner_config: A LearnerConfig, the learner configuration.
      data_config: A DataConfig, the data configuration.
    """
        self.train_learner_class = train_learner
        self.eval_learner_class = eval_learner
        self.is_training = is_training
        self.dataset_list = dataset_list
        self.checkpoint_dir = checkpoint_dir
        self.summary_dir = summary_dir
        self.eval_finegrainedness = eval_finegrainedness
        self.eval_finegrainedness_split = eval_finegrainedness_split
        self.eval_imbalance_dataset = eval_imbalance_dataset

        self.eval_split = 'test'
        if eval_finegrainedness:
            # The fine- vs coarse-grained evaluation may be performed on the
            # training graph, as it exhibits greater variety in this respect.
            self.eval_split = eval_finegrainedness_split

        if eval_finegrainedness or eval_imbalance_dataset:
            # We restrict this analysis to the binary classification setting.
            tf.logging.info(
                'Forcing the number of {} classes to be 2, since '
                'the finegrainedness analysis is applied on binary '
                'classification tasks only.'.format(
                    eval_finegrainedness_split))
            if eval_finegrainedness and eval_finegrainedness_split == 'train':
                num_train_classes = 2
            else:
                num_test_classes = 2

        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.num_train_examples = num_train_examples
        self.num_test_examples = num_test_examples
        msg = ('num_train_classes: {}, num_test_classes: {}, '
               'num_train_examples: {}, num_test_examples: {}').format(
                   num_train_classes, num_test_classes, num_train_examples,
                   num_test_examples)
        tf.logging.info(msg)

        self.learn_config = learn_config
        self.learner_config = learner_config

        if self.learn_config.transductive_batch_norm:
            tf.logging.warn('Using transductive batch norm!')

        # Only applicable for non-transductive batch norm. The correct
        # implementation here involves computing the mean and variance based on the
        # support set and then using them to batch normalize the query set. During
        # meta-learning, we allow the gradients to flow through those moments.
        self.backprop_through_moments = True
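        # A minimal sketch (hypothetical names, illustration only) of the scheme
        # described above:
        #   support_mean, support_var = tf.nn.moments(support_acts, axes=[0])
        #   query_acts = tf.nn.batch_normalization(
        #       query_acts, support_mean, support_var, offset, scale, 1e-5)
        # With backprop_through_moments=True the moments are not wrapped in
        # tf.stop_gradient, so meta-gradients flow through them.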

        self.data_config = data_config
        # Get the image shape.
        self.image_shape = [data_config.image_height] * 2 + [3]

        # Create the benchmark specification.
        (self.benchmark_spec,
         self.valid_benchmark_spec) = self.get_benchmark_specification()
        if self.valid_benchmark_spec is None:
            # This means that ImageNet is not a dataset in the given benchmark spec.
            # In this case the validation will be carried out on randomly-sampled
            # episodes from the meta-validation sets of all given datasets.
            self.valid_benchmark_spec = self.benchmark_spec

        # Which splits to support depends on whether we are in the meta-training
        # phase or not. If we are, we need the train split, and the valid one for
        # early-stopping. If not, we only need the test split.
        if self.is_training:
            self.required_splits = ['train', 'valid']
        else:
            self.required_splits = [self.eval_split]

        # Get the training, validation and testing specifications.
        # Each is either an EpisodeSpecification or a BatchSpecification.
        split_episode_or_batch_specs = {}
        if 'train' in self.required_splits:
            split_episode_or_batch_specs[
                'train'] = self._create_train_specification()
        for split in ['valid', 'test']:
            if split not in self.required_splits:
                continue
            split_episode_or_batch_specs[
                split] = self._create_held_out_specification(split)
        self.split_episode_or_batch_specs = split_episode_or_batch_specs

        # Get the next data (episode or batch) for the different splits.
        self.next_data = {}
        for split in self.required_splits:
            self.next_data[split] = self.build_data(split)

        # Initialize the learners.
        self.ema_object = None  # Using dummy EMA object for now.
        self.learners = {}
        self.embedding_fn = learner.NAME_TO_EMBEDDING_NETWORK[
            self.learner_config.embedding_network]
        if 'train' in self.required_splits:
            self.learners['train'] = (self.create_train_learner(
                self.train_learner_class, self.get_next('train')))
        if self.eval_learner_class is not None:
            for split in ['valid', 'test']:
                if split not in self.required_splits:
                    continue
                self.learners[split] = self.create_eval_learner(
                    self.eval_learner_class, self.get_next(split))

        # Get the Tensors for the losses / accuracies of the different learners.
        self.losses = dict(
            zip(self.required_splits, [
                self.learners[split].compute_loss()
                for split in self.required_splits
            ]))
        self.accs = dict(
            zip(self.required_splits, [
                self.learners[split].compute_accuracy()
                for split in self.required_splits
            ]))

        # Set self.way, self.shots to Tensors for the way/shots of the next episode.
        self.set_way_shots_classes_logits_targets()

        # Get an optimizer and the operation for meta-training.
        self.train_op = None
        if self.is_training:
            global_step = tf.train.get_or_create_global_step()
            learning_rate = self.learner_config.learning_rate
            if self.learner_config.decay_learning_rate:
                learning_rate = tf.train.exponential_decay(
                    self.learner_config.learning_rate,
                    global_step,
                    decay_steps=self.learner_config.decay_every,
                    decay_rate=self.learner_config.decay_rate,
                    staircase=True)
            tf.summary.scalar('learning_rate', learning_rate)
            self.optimizer = tf.train.AdamOptimizer(learning_rate)
            self.train_op = self.get_train_op(global_step)

        vars_to_restore = []
        # Omit from reloading any variable whose name contains any of the
        # substrings in the following list, e.g. variables that track iterator
        # state, since iterator state is not saved.
        omit_substrings = FLAGS.omit_from_saving_and_reloading.split(',')
        tf.logging.info('Omitting from saving / reloading any variable that '
                        'contains any of the following substrings: %s' %
                        omit_substrings)
        for var in tf.global_variables():
            if not any(
                [substring in var.name for substring in omit_substrings]):
                vars_to_restore.append(var)
            else:
                tf.logging.info('Omitting variable %s' % var.name)
        self.saver = tf.train.Saver(var_list=vars_to_restore, max_to_keep=500)

        if self.checkpoint_dir is not None:
            if not tf.gfile.Exists(self.checkpoint_dir):
                tf.gfile.MakeDirs(self.checkpoint_dir)

        # Initialize a Session.
        self.initialize_session()
        self.create_summary_writer()
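
For context, a hedged sketch of constructing such a Trainer; every value below is a hypothetical placeholder, and the Learner classes and the LearnConfig, LearnerConfig, and DataConfig instances are assumed to come from the surrounding codebase.

    # Hypothetical usage, illustration only.
    trainer = Trainer(
        train_learner=MyTrainLearner,            # a Learner class
        eval_learner=MyEvalLearner,              # a Learner class
        is_training=True,
        dataset_list=['omniglot', 'aircraft'],   # placeholder dataset names
        checkpoint_dir='/tmp/checkpoints',
        summary_dir='/tmp/summaries',
        eval_finegrainedness=False,
        eval_finegrainedness_split='train',
        eval_imbalance_dataset='',
        num_train_classes=None,
        num_test_classes=None,
        num_train_examples=None,
        num_test_examples=None,
        learn_config=learn_config,               # a LearnConfig
        learner_config=learner_config,           # a LearnerConfig
        data_config=data_config)                 # a DataConfig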
Example #3
    def initialize_session(self):
        """Initializes a tf Session."""
        if ENABLE_TF_OPTIMIZATIONS:
            self.sess = tf.Session()
        else:
            rewriter_config = rewriter_config_pb2.RewriterConfig(
                disable_model_pruning=True,
                constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                remapping=rewriter_config_pb2.RewriterConfig.OFF,
                shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
                loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                memory_optimization=rewriter_config_pb2.RewriterConfig.
                NO_MEM_OPT)
            graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
            session_config = tf.ConfigProto(graph_options=graph_options)
            self.sess = tf.Session(config=session_config)

        # Restore or initialize the variables.
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())
        if self.learner_config.checkpoint_for_eval:
            # Requested a specific checkpoint.
            self.saver.restore(self.sess,
                               self.learner_config.checkpoint_for_eval)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.checkpoint_for_eval)
        else:
            # Continue from the latest checkpoint if one exists.
            # This handles fault-tolerance.
            latest_checkpoint = None
            if self.checkpoint_dir is not None:
                latest_checkpoint = tf.train.latest_checkpoint(
                    self.checkpoint_dir)
            if latest_checkpoint:
                self.saver.restore(self.sess, latest_checkpoint)
                tf.logging.info('Restored checkpoint: %s' % latest_checkpoint)
            else:
                tf.logging.info('No previous checkpoint.')
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

        # For episodic models, optionally use pretrained weights at the start of
        # training. If so, the embedding weights are overwritten, while taking
        # care not to restore the Adam parameters.
        if self.learner_config.pretrained_checkpoint and not self.sess.run(
                tf.train.get_global_step()):
            self.saver.restore(self.sess,
                               self.learner_config.pretrained_checkpoint)
            tf.logging.info('Restored checkpoint: %s' %
                            self.learner_config.pretrained_checkpoint)
            # We only want the embedding weights of the checkpoint we just restored.
            # So we re-initialize everything that's not an embedding weight. Also,
            # since this episodic finetuning procedure is a different optimization
            # problem from the original training of the baseline whose embedding
            # weights are re-used, we do not reload Adam's variables and instead
            # learn them from scratch.
            vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
            for var in tf.global_variables():
                if (any(keyword in var.name for keyword in EMBEDDING_KEYWORDS)
                        and 'adam' not in var.name.lower()):
                    embedding_var_names.append(var.name)
                    continue
                vars_to_reinit.append(var)
                vars_to_reinit_names.append(var.name)
            tf.logging.info('Initializing all variables except for %s.' %
                            embedding_var_names)
            self.sess.run(tf.variables_initializer(vars_to_reinit))
            tf.logging.info('Re-initialized vars %s.' % vars_to_reinit_names)
Example #4
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict passed in by the Estimator; when running on TPU it
        contains the TPU context under the "context" key
      config: a RunConfig (unused)

    Returns:
      a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when not using TPU)
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(features[key])
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model,
                            (transformer.Bitransformer,
                             transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary mapping feature names to mtf.Tensors

        Returns:
          a dictionary with the scalar loss under the "loss" key
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif (isinstance(transformer_model, transformer.Bitransformer)
                      or model_type == "bi_student_teacher"):
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule
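            # Illustration only (hypothetical schedule): a callable passed as
            # learning_rate_schedule could implement, e.g., inverse-square-root
            # decay:
            #   def learning_rate_schedule(step):
            #     return 0.01 * tf.math.rsqrt(tf.maximum(tf.to_float(step), 1.0))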

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])