def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
  """Builds the estimator spec for PREDICT mode.

  Samples from the model on `mesh`, lowers the Mesh TensorFlow graph with
  `mesh_impl`, and exports the samples as a TensorFlow tensor.

  Args:
    features: dict of input features. "inputs" is read (and required) when
      `self.has_input` is true; "infer_targets" is optional.
    mesh: the `mtf.Mesh` the sampling graph is built on.
    mesh_impl: mesh implementation used to lower `mesh`.
    use_tpu: if true, return a `TPUEstimatorSpec` (after removing summaries);
      otherwise return a `tf.estimator.EstimatorSpec`.

  Returns:
    An estimator spec whose predictions dict always has "outputs", plus
    "targets" and "inputs" when those features are available.
  """
  mtf_samples = self.sample(features, mesh)
  lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
  outputs = lowering.export_to_tf_tensor(mtf_samples)
  if self.has_input:
    # Cut the leading dimension down to the batch size of features["inputs"]
    # (presumably dropping batch padding -- TODO confirm).
    ndims = len(outputs.shape.as_list())
    actual_batch_size = tf.shape(features["inputs"])[0]
    outputs = tf.slice(
        outputs, [0] * ndims, [actual_batch_size] + [-1] * (ndims - 1))

  inputs = features.get("inputs")
  targets = features.get("infer_targets")
  if targets is None:
    targets = inputs

  # Estimator specs require every prediction value to be a Tensor, so the
  # optional entries are only added when the feature is actually present.
  predictions = {"outputs": outputs}
  if targets is not None:
    predictions["targets"] = targets
  if inputs is not None:
    predictions["inputs"] = inputs

  if use_tpu:
    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])
  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.PREDICT,
      predictions=predictions,
      prediction_hooks=[mtf.MtfRestoreHook(lowering)])
def _tpu_estimator_spec_eval(self, features, logits, labels, loss,
                             losses_dict):
  """Construct EstimatorSpec for TPU EVAL mode.

  Args:
    features: not used by this method.
    logits: a Tensor, or a dict of Tensors that is forwarded to the TPU eval
      metrics function as keyword arguments.
    labels: label Tensor passed to the eval metrics function.
    loss: scalar loss for the spec.
    losses_dict: not used by this method.

  Returns:
    A `TPUEstimatorSpec` for EVAL mode.

  Raises:
    NotImplementedError: if `self.hparams` has no `problem` attribute.
  """
  del losses_dict
  hparams = self.hparams

  if not hasattr(hparams, "problem"):
    raise NotImplementedError(
        "hparams is missing attribute `problem`. NasSeq2Seq must "
        "be used with a problem.")

  problem = hparams.problem
  t2t_model.remove_summaries()
  eval_metrics_fn = t2t_model.create_tpu_eval_metrics_fn(problem, hparams)
  if isinstance(logits, dict):
    # For TPU, a logits dict is passed as keyword arguments to
    # eval_metrics_fn, so the labels are added to those arguments. Build a new
    # dict rather than updating `logits` in place, so the caller's dict is
    # left untouched.
    metric_args = dict(logits, labels=labels)
  else:
    metric_args = [logits, labels]
  return contrib.tpu().TPUEstimatorSpec(
      tf.estimator.ModeKeys.EVAL,
      eval_metrics=(eval_metrics_fn, metric_args),
      loss=loss)
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
  """Return the PREDICT-mode estimator spec for a Mesh TensorFlow model.

  The model's samples are anonymized, lowered onto `mesh_impl` and exported
  to a TensorFlow tensor. The predictions always contain "outputs"; the
  "infer_targets" and "inputs" features are copied through only when they
  are present and not None.

  Args:
    features: dict of input features; "inputs" is required when
      `self.has_input` is true.
    mesh: the `mtf.Mesh` to sample on.
    mesh_impl: mesh implementation used for lowering.
    use_tpu: selects `TPUEstimatorSpec` (true) or `EstimatorSpec` (false).

  Returns:
    The estimator spec for PREDICT mode.
  """
  anonymous_samples = mtf.anonymize(self.sample(features, mesh))
  lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
  sampled = lowering.export_to_tf_tensor(anonymous_samples)

  if self.has_input:
    # Keep only the first tf.shape(inputs)[0] rows of the exported samples.
    rank = len(sampled.shape.as_list())
    batch = tf.shape(features["inputs"])[0]
    begin = [0] * rank
    size = [batch] + [-1] * (rank - 1)
    sampled = tf.slice(sampled, begin, size)

  predictions = {"outputs": sampled}
  for key in ("infer_targets", "inputs"):
    value = features.get(key)
    if value is not None:
      predictions[key] = value

  mode = tf.estimator.ModeKeys.PREDICT
  if not use_tpu:
    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  t2t_model.remove_summaries()
  return tpu_estimator.TPUEstimatorSpec(
      mode=mode,
      predictions=predictions,
      prediction_hooks=[mtf.MtfRestoreHook(lowering)])
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
  """Constructs `tf.estimator.EstimatorSpec` for EVAL (evaluation) mode.

  Reports theorem and premise accuracy. `logits` is unpacked into the local
  `eval_metrics_fn`, so it is presumably a dict with keys theorem_logits,
  theorem_labels, premise_logits and premise_labels -- TODO confirm against
  the model's outputs.

  Args:
    features: features forwarded to the problem's eval hooks.
    logits: dict of metric inputs (see above); also used as predictions.
    labels: not used by this method.
    loss: scalar loss for the spec.
    losses_dict: not used by this method.

  Returns:
    A `TPUEstimatorSpec` when running XLA-compiled, otherwise an
    `EstimatorSpec` with eval metric ops and evaluation hooks.
  """
  del losses_dict

  def eval_metrics_fn(theorem_logits, theorem_labels, premise_logits,
                      premise_labels):
    return dict(
        theorem_accuracy=accuracy(theorem_logits, theorem_labels),
        premise_accuracy=accuracy(premise_logits, premise_labels))

  if t2t_model.common_layers.is_xla_compiled():
    # Note: important to call this before remove_summaries()
    if self.hparams.tpu_enable_host_call:
      host_call = self.create_eval_host_call()
    else:
      host_call = None

    t2t_model.remove_summaries()

    return tf.contrib.tpu.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        eval_metrics=(eval_metrics_fn, logits),
        host_call=host_call,
        loss=loss)

  evaluation_hooks = []
  # merge_all() returns None when the graph has no summaries, and
  # SummarySaverHook raises ValueError unless exactly one of summary_op /
  # scaffold is given, so only add the hook when there is something to save.
  summary_op = tf.summary.merge_all()
  if summary_op is not None:
    eval_dir = os.path.join(self.hparams.model_dir,
                            self.hparams.get('eval_dir_name', 'eval'))
    evaluation_hooks.append(
        tf.train.SummarySaverHook(
            save_steps=1, output_dir=eval_dir, summary_op=summary_op))
  evaluation_hooks += self.hparams.problem.eval_hooks(
      features, logits, self.hparams)

  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.EVAL,
      predictions=logits,
      eval_metric_ops=eval_metrics_fn(**logits),
      evaluation_hooks=evaluation_hooks,
      loss=loss)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None):
  """Estimator model_fn for a Mesh TensorFlow model.

  Args:
    cls: the model class to instantiate.
    hparams: model hyperparameters; deep-copied before being modified.
    features: dict of input feature tensors.
    labels: labels, forwarded to `model.estimator_spec_eval` in EVAL mode.
    mode: a `tf.estimator.ModeKeys` value.
    config: run config; only its `data_parallelism` is read, and only when
      not running on TPU.
    params: estimator params; "use_tpu" selects TPU execution and, on TPU,
      "context" supplies host placement, replica count and device assignment.
    decode_hparams: hyperparameters merged into `hparams` in PREDICT mode.

  Returns:
    An `EstimatorSpec` or `TPUEstimatorSpec` for `mode`.
  """
  hparams = copy.deepcopy(hparams)
  # Normalize to a bool: `params` may be None or may not set "use_tpu".
  use_tpu = bool(params and params.get("use_tpu", False))
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning(
            "Overriding hparams.%s with %s from decode_hparams" % (k, v))
      setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams, mode, data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    # Without a config there is no data_parallelism object; use the default
    # (empty) device for every mesh position instead of crashing.
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  # Logits are exported only in the EVAL branch below; exporting them here
  # as well only added unused ops to the graph.

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

  # EVAL mode
  if mode == tf.estimator.ModeKeys.EVAL:
    tf_logits = lowering.export_to_tf_tensor(logits)
    return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                     restore_hook, use_tpu)

  if use_tpu:
    # TPU host call. Important: need to be called before remove_summaries()
    if hparams.tpu_enable_host_call:
      host_call = t2t_model.create_host_call(hparams.model_dir)
    else:
      host_call = None

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook])
  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.TRAIN,
      loss=tf_loss,
      train_op=train_op,
      training_chief_hooks=[restore_hook, saver_hook])
def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
  """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode.

  When `sparsity_technique` contains "pruning" and no trained masks are being
  loaded, a magnitude-pruning mask update is chained after the optimizer
  step. Checkpoint initialization is controlled by `warm_start_from`, or by
  `load_masks_from` (optionally combined with `load_weights_from`). It is
  applied the same way on TPU (via `scaffold_fn`) and off TPU (directly).

  Args:
    loss: scalar training loss.
    num_async_replicas: forwarded to `self.optimize`.
    use_tpu: if true, return a `TPUEstimatorSpec`; otherwise an
      `EstimatorSpec`.

  Returns:
    The TRAIN-mode estimator spec.
  """
  train_op = self.optimize(loss, num_async_replicas=num_async_replicas,
                           use_tpu=use_tpu)

  sparsity_technique = self._hparams.get("sparsity_technique")
  # `get` returns None when the hparam is unset; check it first, because the
  # substring test would raise TypeError on None.
  if sparsity_technique and "pruning" in sparsity_technique:
    if not self._hparams.load_masks_from:
      # If we are loading trained masks, don't add the mask update
      # step to the training process and keep the masks static
      with tf.control_dependencies([train_op]):
        mp_hparams = pruning_hparams(self._hparams, use_tpu,
                                     sparsity_technique == "random_pruning")
        p = magnitude_pruning.Pruning(
            mp_hparams, global_step=tf.train.get_global_step())
        mask_update_op = p.conditional_mask_update_op()
        train_op = mask_update_op
    check_global_sparsity()

  if use_tpu:
    if self._hparams.warm_start_from:
      def scaffold_fn():
        self.initialize_from_ckpt(self._hparams.warm_start_from)
        return tf.train.Scaffold()
    elif self._hparams.load_masks_from and self._hparams.load_weights_from:
      def scaffold_fn():
        self.initialize_masks_from_ckpt(self._hparams.load_masks_from)
        self.initialize_non_masks_from_ckpt(self._hparams.load_weights_from)
        return tf.train.Scaffold()
    elif self._hparams.load_masks_from:
      def scaffold_fn():
        self.initialize_masks_from_ckpt(self._hparams.load_masks_from)
        return tf.train.Scaffold()
    else:
      scaffold_fn = None

    # Note: important to call this before remove_summaries()
    if self.hparams.tpu_enable_host_call:
      host_call = t2t_model.create_host_call(self.hparams.model_dir)
    else:
      host_call = None

    t2t_model.remove_summaries()

    return contrib_tpu.TPUEstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        scaffold_fn=scaffold_fn)

  if self._hparams.warm_start_from:
    self.initialize_from_ckpt(self._hparams.warm_start_from)
  elif self._hparams.load_masks_from:
    self.initialize_masks_from_ckpt(self._hparams.load_masks_from)
    # Match the TPU path: when non-mask weights are given too, load them as
    # well instead of silently ignoring `load_weights_from`.
    if self._hparams.load_weights_from:
      self.initialize_non_masks_from_ckpt(self._hparams.load_weights_from)

  return tf_estimator.EstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                    loss=loss,
                                    train_op=train_op)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  """Estimator model_fn for a Mesh TensorFlow model.

  Args:
    cls: the model class to instantiate.
    hparams: model hyperparameters; deep-copied before being modified.
    features: dict of input feature tensors.
    labels: labels, forwarded to `model.estimator_spec_eval` in EVAL mode.
    mode: a `tf.estimator.ModeKeys` value.
    config: run config; only its `data_parallelism` is read, and only when
      not running on TPU.
    params: estimator params; on TPU, "context" supplies host placement,
      replica count and device assignment.
    decode_hparams: hyperparameters merged into `hparams` in PREDICT mode.
    use_tpu: whether the model runs on TPU.

  Returns:
    An `EstimatorSpec` or `TPUEstimatorSpec` for `mode`.
  """
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # merge decode_hparams into hparams if present
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                           % (k, v))
      setattr(hparams, k, v)

  # Instantiate model
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams,
      mode,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)

  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  # Logits are exported only in the EVAL branch below; exporting them here
  # as well only added unused ops to the graph.

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

  # EVAL mode
  if mode == tf.estimator.ModeKeys.EVAL:
    tf_logits = lowering.export_to_tf_tensor(logits)
    return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                     restore_hook, use_tpu)

  if use_tpu:
    # TPU host call. Important: need to be called before remove_summaries()
    if hparams.tpu_enable_host_call:
      host_call = t2t_model.create_host_call(hparams.model_dir)
    else:
      host_call = None

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook])
  return tf.estimator.EstimatorSpec(
      tf.estimator.ModeKeys.TRAIN,
      loss=tf_loss,
      train_op=train_op,
      training_chief_hooks=[restore_hook, saver_hook])