def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: a parameter dictionary; when running on TPU, the TPUEstimator
      context is available under params["context"]
    config: a RunConfig (ignored)

  Returns:
    a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when training
    without a TPU)
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key, allow_missing=False):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string
      allow_missing: a boolean; if True, return None for a missing feature
        instead of raising an error

    Returns:
      a mtf.Tensor with dtype int32 and shape
      [outer_batch_dim, batch_dim, length_dim]
    """
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
    if key not in features:
      if allow_missing:
        return None
      else:
        raise ValueError("feature not found: %s (features: %s)"
                         % (key, features))
    tf.logging.info("Import feature %s: %s" % (key, features[key]))
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [outer_batch_size, batch_size // outer_batch_size, -1])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if model_type == "lm":
    _, _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if mode == tf.estimator.ModeKeys.EVAL:
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    labels = lowering.export_to_tf_tensor(anon_targets)
    restore_hook = mtf.MtfRestoreHook(lowering)

    # Assigning a default to `metric_names` only when it is None would make
    # it a local variable inside this function, so use a separate local.
    local_metric_names = metric_names or ["token_accuracy"]

    def metric_fn(labels, outputs):
      return get_metric_fns(local_metric_names, labels, outputs)

    eval_metrics = (metric_fn, [labels, outputs])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        # Unfortunately TPUEstimatorSpec requires us to provide a value for
        # loss when in EVAL mode. Since we are sampling or decoding from the
        # model, we don't have a loss to report.
        loss=tf.constant(0.),
        evaluation_hooks=[restore_hook],
        eval_metrics=eval_metrics)

  if isinstance(transformer_model, transformer.Unitransformer):
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation", True),
        position=_import_feature("targets_position", True),
    )
  elif isinstance(transformer_model, transformer.Bitransformer):
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation", True),
        decoder_sequence_id=_import_feature("targets_segmentation", True),
        encoder_position=_import_feature("inputs_position", True),
        decoder_position=_import_feature("targets_position", True),
    )
  else:
    raise ValueError("unrecognized class")

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      **position_kwargs)

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=save_steps,
        saver=saver,
        listeners=[saver_listener])
    gin_config_saver_hook = gin.tf.GinConfigSaverHook(
        model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
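
# Illustrative only, not part of the original code: a minimal sketch of how a
# model_fn like my_model_fn above is typically handed to a TPUEstimator.
# `tpu_address` and the batch size are hypothetical placeholders, my_model_fn
# is assumed to close over use_tpu, mesh_shape, layout_rules, etc. from its
# enclosing scope, and additional TPUConfig options (e.g. the input pipeline
# mode) may be needed so that params["context"] is populated for the model_fn.
def example_build_estimator(tpu_address, model_dir, batch_size):
  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
  run_config = tf.contrib.tpu.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
  return tf.contrib.tpu.TPUEstimator(
      model_fn=my_model_fn,
      config=run_config,
      train_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=True)
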
def __init__(self, train_learner, eval_learner, is_training, dataset_list,
             checkpoint_dir, summary_dir, eval_finegrainedness,
             eval_finegrainedness_split, eval_imbalance_dataset,
             num_train_classes, num_test_classes, num_train_examples,
             num_test_examples, learn_config, learner_config, data_config):
  """Initializes a Trainer.

  Args:
    train_learner: A Learner to be used for meta-training.
    eval_learner: A Learner to be used for meta-validation or meta-testing.
    is_training: Bool, whether to train or just evaluate.
    dataset_list: A list of names of datasets to include in the benchmark.
      This can be any subset of the supported datasets.
    checkpoint_dir: A string, the path to the checkpoint directory, or None
      if no checkpointing should occur.
    summary_dir: A string, the path to the summary directory, or None if no
      summaries should be saved.
    eval_finegrainedness: Whether to perform binary ImageNet evaluation for
      assessing the performance on fine- vs coarse-grained tasks.
    eval_finegrainedness_split: The subgraph of ImageNet to perform the
      aforementioned analysis on. Notably, if this is 'train', we need to
      ensure that training data is used episodically, even if the given
      model is the baseline model which usually uses batches for training.
    eval_imbalance_dataset: A dataset on which to perform evaluation for
      assessing how class imbalance affects performance in binary episodes.
      By default it is empty and no imbalance analysis is performed.
    num_train_classes: An int or None, the number of classes in episodic
      meta-training.
    num_test_classes: An int or None, the number of classes in episodic
      meta-testing.
    num_train_examples: An int or None, the number of support examples.
    num_test_examples: An int or None, the number of query examples.
    learn_config: A LearnConfig, the learning configuration.
    learner_config: A LearnerConfig, the learner configuration.
    data_config: A DataConfig, the data configuration.
  """
  self.train_learner_class = train_learner
  self.eval_learner_class = eval_learner
  self.is_training = is_training
  self.dataset_list = dataset_list
  self.checkpoint_dir = checkpoint_dir
  self.summary_dir = summary_dir
  self.eval_finegrainedness = eval_finegrainedness
  self.eval_finegrainedness_split = eval_finegrainedness_split
  self.eval_imbalance_dataset = eval_imbalance_dataset

  self.eval_split = 'test'
  if eval_finegrainedness:
    # The fine- vs coarse-grained evaluation may potentially be performed on
    # the training graph as it exhibits greater variety in this aspect.
    self.eval_split = eval_finegrainedness_split

  if eval_finegrainedness or eval_imbalance_dataset:
    # We restrict this analysis to the binary classification setting.
    tf.logging.info('Forcing the number of {} classes to be 2, since '
                    'the finegrainedness analysis is applied on binary '
                    'classification tasks only.'.format(
                        eval_finegrainedness_split))
    if eval_finegrainedness and eval_finegrainedness_split == 'train':
      num_train_classes = 2
    else:
      num_test_classes = 2

  self.num_train_classes = num_train_classes
  self.num_test_classes = num_test_classes
  self.num_train_examples = num_train_examples
  self.num_test_examples = num_test_examples
  msg = ('num_train_classes: {}, num_test_classes: {}, '
         'num_train_examples: {}, num_test_examples: {}').format(
             num_train_classes, num_test_classes, num_train_examples,
             num_test_examples)
  tf.logging.info(msg)

  self.learn_config = learn_config
  self.learner_config = learner_config

  if self.learn_config.transductive_batch_norm:
    tf.logging.warn('Using transductive batch norm!')

  # Only applicable for non-transductive batch norm. The correct
  # implementation here involves computing the mean and variance based on the
  # support set and then using them to batch normalize the query set. During
  # meta-learning, we allow the gradients to flow through those moments.
  self.backprop_through_moments = True

  self.data_config = data_config
  # Get the image shape.
  self.image_shape = [data_config.image_height] * 2 + [3]

  # Create the benchmark specification.
  (self.benchmark_spec,
   self.valid_benchmark_spec) = self.get_benchmark_specification()
  if self.valid_benchmark_spec is None:
    # This means that ImageNet is not a dataset in the given benchmark spec.
    # In this case the validation will be carried out on randomly-sampled
    # episodes from the meta-validation sets of all given datasets.
    self.valid_benchmark_spec = self.benchmark_spec

  # Which splits to support depends on whether we are in the meta-training
  # phase or not. If we are, we need the train split, and the valid one for
  # early-stopping. If not, we only need the test split.
  if self.is_training:
    self.required_splits = ['train', 'valid']
  else:
    self.required_splits = [self.eval_split]

  # Get the training, validation and testing specifications.
  # Each is either an EpisodeSpecification or a BatchSpecification.
  split_episode_or_batch_specs = {}
  if 'train' in self.required_splits:
    split_episode_or_batch_specs['train'] = self._create_train_specification()
  for split in ['valid', 'test']:
    if split not in self.required_splits:
      continue
    split_episode_or_batch_specs[split] = (
        self._create_held_out_specification(split))
  self.split_episode_or_batch_specs = split_episode_or_batch_specs

  # Get the next data (episode or batch) for the different splits.
  self.next_data = {}
  for split in self.required_splits:
    self.next_data[split] = self.build_data(split)

  # Initialize the learners.
  self.ema_object = None  # Using dummy EMA object for now.
  self.learners = {}
  self.embedding_fn = learner.NAME_TO_EMBEDDING_NETWORK[
      self.learner_config.embedding_network]
  if 'train' in self.required_splits:
    self.learners['train'] = self.create_train_learner(
        self.train_learner_class, self.get_next('train'))
  if self.eval_learner_class is not None:
    for split in ['valid', 'test']:
      if split not in self.required_splits:
        continue
      self.learners[split] = self.create_eval_learner(
          self.eval_learner_class, self.get_next(split))

  # Get the Tensors for the losses / accuracies of the different learners.
  self.losses = dict(
      zip(self.required_splits, [
          self.learners[split].compute_loss()
          for split in self.required_splits
      ]))
  self.accs = dict(
      zip(self.required_splits, [
          self.learners[split].compute_accuracy()
          for split in self.required_splits
      ]))

  # Set self.way and self.shots to Tensors for the way/shots of the next
  # episode.
  self.set_way_shots_classes_logits_targets()

  # Get an optimizer and the operation for meta-training.
  self.train_op = None
  if self.is_training:
    global_step = tf.train.get_or_create_global_step()
    learning_rate = self.learner_config.learning_rate
    if self.learner_config.decay_learning_rate:
      learning_rate = tf.train.exponential_decay(
          self.learner_config.learning_rate,
          global_step,
          decay_steps=self.learner_config.decay_every,
          decay_rate=self.learner_config.decay_rate,
          staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)
    self.optimizer = tf.train.AdamOptimizer(learning_rate)
    self.train_op = self.get_train_op(global_step)

  vars_to_restore = []
  # Omit from saving / reloading any variable whose name contains any of the
  # substrings below, for example variables that track iterator state, since
  # iterator state is not saved.
  omit_substrings = FLAGS.omit_from_saving_and_reloading.split(',')
  tf.logging.info('Omitting from saving / reloading any variable that '
                  'contains any of the following substrings: %s' %
                  omit_substrings)
  for var in tf.global_variables():
    if not any(substring in var.name for substring in omit_substrings):
      vars_to_restore.append(var)
    else:
      tf.logging.info('Omitting variable %s' % var.name)
  self.saver = tf.train.Saver(var_list=vars_to_restore, max_to_keep=500)

  if self.checkpoint_dir is not None:
    if not tf.gfile.Exists(self.checkpoint_dir):
      tf.gfile.MakeDirs(self.checkpoint_dir)

  # Initialize a Session.
  self.initialize_session()
  self.create_summary_writer()
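
# Illustrative only, not part of the Trainer: the staircase schedule created
# above multiplies the base learning rate by decay_rate once every
# `decay_every` steps. A pure-Python equivalent with hypothetical values:
def example_staircase_lr(step, base_lr=1e-3, decay_every=5000, decay_rate=0.5):
  return base_lr * decay_rate**(step // decay_every)

# example_staircase_lr(0) == 1e-3 and example_staircase_lr(5000) == 5e-4; the
# rate is constant within each 5000-step window, mirroring what
# tf.train.exponential_decay(..., staircase=True) computes.
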
def initialize_session(self):
  """Initializes a tf.Session."""
  if ENABLE_TF_OPTIMIZATIONS:
    self.sess = tf.Session()
  else:
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        remapping=rewriter_config_pb2.RewriterConfig.OFF,
        shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
        loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
    graph_options = tf.GraphOptions(rewrite_options=rewriter_config)
    session_config = tf.ConfigProto(graph_options=graph_options)
    self.sess = tf.Session(config=session_config)

  # Restore or initialize the variables.
  self.sess.run(tf.global_variables_initializer())
  self.sess.run(tf.local_variables_initializer())
  if self.learner_config.checkpoint_for_eval:
    # Requested a specific checkpoint.
    self.saver.restore(self.sess, self.learner_config.checkpoint_for_eval)
    tf.logging.info('Restored checkpoint: %s' %
                    self.learner_config.checkpoint_for_eval)
  else:
    # Continue from the latest checkpoint if one exists.
    # This handles fault-tolerance.
    latest_checkpoint = None
    if self.checkpoint_dir is not None:
      latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
    if latest_checkpoint:
      self.saver.restore(self.sess, latest_checkpoint)
      tf.logging.info('Restored checkpoint: %s' % latest_checkpoint)
    else:
      tf.logging.info('No previous checkpoint.')
      self.sess.run(tf.global_variables_initializer())
      self.sess.run(tf.local_variables_initializer())

  # For episodic models, potentially use pretrained weights at the start of
  # training. If so, the embedding weights are overwritten, while taking care
  # not to restore the Adam parameters.
  if self.learner_config.pretrained_checkpoint and not self.sess.run(
      tf.train.get_global_step()):
    self.saver.restore(self.sess, self.learner_config.pretrained_checkpoint)
    tf.logging.info('Restored checkpoint: %s' %
                    self.learner_config.pretrained_checkpoint)
    # We only want the embedding weights of the checkpoint we just restored.
    # So we re-initialize everything that's not an embedding weight. Also,
    # since this episodic fine-tuning procedure is a different optimization
    # problem than the original training of the baseline whose embedding
    # weights are re-used, we do not reload Adam's variables and instead
    # learn them from scratch.
    vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
    for var in tf.global_variables():
      if (any(keyword in var.name for keyword in EMBEDDING_KEYWORDS) and
          'adam' not in var.name.lower()):
        embedding_var_names.append(var.name)
        continue
      vars_to_reinit.append(var)
      vars_to_reinit_names.append(var.name)
    tf.logging.info('Initializing all variables except for %s.' %
                    embedding_var_names)
    self.sess.run(tf.variables_initializer(vars_to_reinit))
    tf.logging.info('Re-initialized vars %s.' % vars_to_reinit_names)
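
# Illustrative only, not part of initialize_session: tf.variables_initializer
# re-initializes exactly the variables it is given, which is what the code
# above relies on to keep restored embedding weights while resetting
# everything else (including Adam slots). A self-contained sketch with
# made-up variables:
def example_partial_reinit():
  tf.reset_default_graph()
  embedding_w = tf.get_variable("embedding/w", initializer=tf.zeros([2]))
  head_w = tf.get_variable("head/w", initializer=tf.zeros([2]))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Pretend both variables were just restored from a checkpoint.
    sess.run([embedding_w.assign([1., 1.]), head_w.assign([1., 1.])])
    # Re-initialize only the non-embedding variable.
    sess.run(tf.variables_initializer([head_w]))
    return sess.run([embedding_w, head_w])  # -> [1., 1.] and [0., 0.]
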
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: a parameter dictionary; when running on TPU, the TPUEstimator
      context is available under params["context"]
    config: a RunConfig (ignored)

  Returns:
    a TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when training
    without a TPU)
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
  batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
  length_dim = mtf.Dimension("length", sequence_length)
  feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

  mtf_features = {}
  for key, x in features.items():
    x = tf.to_int32(features[key])
    x = tf.reshape(
        x, [outer_batch_size, batch_size // outer_batch_size, sequence_length])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = mtf_features["inputs"]
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  elif mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError("We don't expect to use mode == eval.")

  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    num_microbatches = serialize_num_microbatches(
        batch_dim, length_dim, mesh_shape, layout_rules)

    def model_fn(mtf_features):
      """The kind of function we need for mtf.serialize_training_step.

      Args:
        mtf_features: a dictionary mapping feature names to mtf.Tensors

      Returns:
        a dictionary containing the key "loss"
      """
      targets = mtf_features["targets"]
      if model_type == "lm":
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      else:
        inputs = mtf_features["inputs"]

      if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
      elif isinstance(
          transformer_model,
          transformer.Bitransformer) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get("targets_segmentation",
                                                 None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
      else:
        raise ValueError("unrecognized class")

      logits, loss = transformer_model.call_simple(
          inputs=inputs,
          targets=targets,
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          **position_kwargs)
      if num_microbatches > 1:
        loss /= float(num_microbatches)
      del logits
      return {"loss": loss}

    if num_microbatches > 1:
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, model_fn, batch_dim, num_microbatches)
    else:
      loss_dict = model_fn(mtf_features)
      var_grads = mtf.gradients(
          [loss_dict["loss"]],
          [v.outputs[0] for v in graph.trainable_variables])

    loss = loss_dict["loss"]

    if callable(learning_rate_schedule):
      # The following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "step, tf_loss")

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True)

      if use_tpu:
        if tpu_summaries:
          tf.summary.scalar("loss", tf_loss)
          host_call = mtf.utils.create_host_call(model_dir)
          mtf.utils.remove_summaries()
        else:
          host_call = None
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
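
# Illustrative only, not part of the model function above: when the training
# step is serialized into microbatches, the per-microbatch gradients are
# summed, so each microbatch loss is divided by num_microbatches (as done in
# model_fn above) to keep the gradient scale equal to that of a single large
# batch. A numpy sketch with made-up numbers:
import numpy as np

def example_microbatch_equivalence(num_microbatches=2):
  w = 2.0
  x = np.array([1.0, 2.0, 3.0, 4.0])  # a "batch" of four examples
  # Gradient w.r.t. w of the mean of (w * x)**2 over the full batch.
  full_grad = np.mean(2.0 * w * x**2)
  # Split into microbatches, scale each microbatch loss (and hence its
  # gradient) by 1 / num_microbatches, then sum.
  micro_grads = [np.mean(2.0 * w * xm**2) / num_microbatches
                 for xm in np.split(x, num_microbatches)]
  assert np.isclose(full_grad, sum(micro_grads))
  return full_grad, sum(micro_grads)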