Example #1
  def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    mtf_samples = self.sample(features, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    if self.has_input:
      ndims = len(outputs.shape.as_list())
      actual_batch_size = tf.shape(features["inputs"])[0]
      outputs = tf.slice(outputs, [0] * ndims,
                         [actual_batch_size] + [-1] * (ndims - 1))
    predictions = {
        "outputs": outputs,
        "targets": features.get("infer_targets", features.get("inputs")),
        "inputs": features.get("inputs"),
    }
    if use_tpu:
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
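
The tf.slice call above trims away padding rows: when the inference batch is padded up to a fixed size, only the first actual_batch_size rows of the decoded outputs are real. A minimal, self-contained sketch of that trimming idiom, assuming TF1.x graph mode (the tensor values and the static batch size are illustrative; in the example above actual_batch_size is the dynamic tf.shape(features["inputs"])[0]):

import tensorflow as tf  # TF1.x assumed

outputs = tf.constant([[1, 2], [3, 4], [5, 6], [0, 0]])  # batch padded to 4 rows
actual_batch_size = 3  # the real (unpadded) batch size
ndims = len(outputs.shape.as_list())
# Start at the origin; keep actual_batch_size rows and every other dim in full (-1).
trimmed = tf.slice(outputs, [0] * ndims,
                   [actual_batch_size] + [-1] * (ndims - 1))
with tf.Session() as sess:
    print(sess.run(trimmed))  # [[1 2] [3 4] [5 6]]
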
Example #2
  def _tpu_estimator_spec_eval(self, features, logits, labels, loss,
                               losses_dict):
    """Construct EstimatorSpec for TPU EVAL mode."""
    del losses_dict
    hparams = self.hparams

    if not hasattr(hparams, "problem"):
      raise NotImplementedError(
          "hparams is missing attribute `problem`. NasSeq2Seq must "
          "be used with a problem.")

    problem = hparams.problem
    t2t_model.remove_summaries()
    eval_metrics_fn = t2t_model.create_tpu_eval_metrics_fn(problem, hparams)
    if isinstance(logits, dict):
      # For TPU, logits dict will be passed as keyword arguments to
      # eval_metrics_fn. Here we add the labels to those arguments.
      logits.update({"labels": labels})
      return contrib.tpu().TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          eval_metrics=(eval_metrics_fn, logits),
          loss=loss)
    else:
      return contrib.tpu().TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          eval_metrics=(eval_metrics_fn, [logits, labels]),
          loss=loss)
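
Unlike a plain EstimatorSpec, the TPU spec takes eval_metrics as a (metrics_fn, tensors) pair: the tensors (here the logits dict, or the [logits, labels] list) are shipped to the host and passed into metrics_fn there. A hedged sketch of the contract such a function must satisfy, not the actual function returned by create_tpu_eval_metrics_fn:

import tensorflow as tf  # TF1.x assumed

def eval_metrics_fn(logits, labels):
    # Must return a dict of name -> (value, update_op), exactly the shape
    # that eval_metric_ops takes in a non-TPU EstimatorSpec.
    predictions = tf.argmax(logits, axis=-1)
    return {"accuracy": tf.metrics.accuracy(labels=labels,
                                            predictions=predictions)}
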
Example #3
  def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    mtf_samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    if self.has_input:
      ndims = len(outputs.shape.as_list())
      actual_batch_size = tf.shape(features["inputs"])[0]
      outputs = tf.slice(
          outputs, [0] * ndims, [actual_batch_size] + [-1] * (ndims - 1))
    predictions = {
        "outputs": outputs
    }
    if features.get("infer_targets") is not None:
      predictions["infer_targets"] = features["infer_targets"]

    if features.get("inputs") is not None:
      predictions["inputs"] = features["inputs"]

    if use_tpu:
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.PREDICT,
          predictions=predictions,
          prediction_hooks=[mtf.MtfRestoreHook(lowering)])
Example #4
    def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
        """Constructs `tf.estimator.EstimatorSpec` for EVAL (evaluation) mode."""
        del losses_dict

        def eval_metrics_fn(theorem_logits, theorem_labels, premise_logits,
                            premise_labels):
            return dict(theorem_accuracy=accuracy(theorem_logits,
                                                  theorem_labels),
                        premise_accuracy=accuracy(premise_logits,
                                                  premise_labels))

        if t2t_model.common_layers.is_xla_compiled():
            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = self.create_eval_host_call()
            else:
                host_call = None
            t2t_model.remove_summaries()
            return tf.contrib.tpu.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                eval_metrics=(eval_metrics_fn, logits),
                host_call=host_call,
                loss=loss)

        evaluation_hooks = []
        # Create a SummarySaverHook
        eval_dir = os.path.join(self.hparams.model_dir,
                                self.hparams.get("eval_dir_name", "eval"))
        eval_summary_hook = tf.train.SummarySaverHook(
            save_steps=1,
            output_dir=eval_dir,
            summary_op=tf.summary.merge_all())
        evaluation_hooks.append(eval_summary_hook)
        evaluation_hooks += self.hparams.problem.eval_hooks(
            features, logits, self.hparams)

        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            predictions=logits,
            eval_metric_ops=eval_metrics_fn(**logits),
            evaluation_hooks=evaluation_hooks,
            loss=loss)
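
Example #4 calls an accuracy helper that is defined elsewhere in its module. A hypothetical sketch of what such a helper could look like, with the name and signature taken from the call sites above and the body an assumption:

import tensorflow as tf  # TF1.x assumed

def accuracy(logits, labels):
    # A TF1 streaming metric: returns the (value, update_op) pair expected
    # both by eval_metric_ops and by the TPU eval_metrics contract.
    predictions = tf.argmax(logits, axis=-1)
    return tf.metrics.accuracy(labels=labels, predictions=predictions)
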
Example #5
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None):
        hparams = copy.deepcopy(hparams)
        use_tpu = params and params.get("use_tpu", False)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
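
A minimal sketch of how a model_fn like this one is handed to a TF1 Estimator for CPU/GPU training, assuming estimator_model_fn is a classmethod (the cls argument above suggests it). MyMtfModel, my_hparams, and my_input_fn are illustrative names, and the real tensor2tensor wiring may differ; decode_hparams must be bound because Estimator rejects model_fn arguments it does not recognize:

import functools
import tensorflow as tf

model_fn = functools.partial(
    MyMtfModel.estimator_model_fn,  # an MtfModel subclass (illustrative)
    my_hparams,                     # binds the hparams argument
    decode_hparams=None)
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=my_hparams.model_dir)
# estimator.train(input_fn=my_input_fn, max_steps=100000)
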
Example #6
    def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
        """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode."""
        train_op = self.optimize(loss,
                                 num_async_replicas=num_async_replicas,
                                 use_tpu=use_tpu)

        sparsity_technique = self._hparams.get("sparsity_technique")
        if sparsity_technique and "pruning" in sparsity_technique:
            if not self._hparams.load_masks_from:
                # Only add the mask update step when training masks from
                # scratch; loaded, pre-trained masks are kept static.
                with tf.control_dependencies([train_op]):
                    mp_hparams = pruning_hparams(
                        self._hparams, use_tpu,
                        sparsity_technique == "random_pruning")
                    p = magnitude_pruning.Pruning(
                        mp_hparams, global_step=tf.train.get_global_step())
                    mask_update_op = p.conditional_mask_update_op()
                    train_op = mask_update_op
            check_global_sparsity()

        if use_tpu:
            if self._hparams.warm_start_from:

                def scaffold_fn():
                    self.initialize_from_ckpt(self._hparams.warm_start_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from and self._hparams.load_weights_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    self.initialize_non_masks_from_ckpt(
                        self._hparams.load_weights_from)
                    return tf.train.Scaffold()
            elif self._hparams.load_masks_from:

                def scaffold_fn():
                    self.initialize_masks_from_ckpt(
                        self._hparams.load_masks_from)
                    return tf.train.Scaffold()
            else:
                scaffold_fn = None

            # Note: important to call this before remove_summaries()
            if self.hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(self.hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()

            return contrib_tpu.TPUEstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                                loss=loss,
                                                train_op=train_op,
                                                host_call=host_call,
                                                scaffold_fn=scaffold_fn)
        else:
            if self._hparams.warm_start_from:
                self.initialize_from_ckpt(self._hparams.warm_start_from)
            elif self._hparams.load_masks_from:
                self.initialize_masks_from_ckpt(self._hparams.load_masks_from)

            return tf_estimator.EstimatorSpec(tf_estimator.ModeKeys.TRAIN,
                                              loss=loss,
                                              train_op=train_op)
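
The scaffold_fn closures above let TPUEstimator defer checkpoint initialization until after the graph is built. A hypothetical sketch of the pattern, assuming the initialize_*_from_ckpt helpers wrap tf.train.init_from_checkpoint (an assumption; their bodies are not shown in the example, and the checkpoint path is a placeholder):

import tensorflow as tf

def scaffold_fn():
    # Rewrites variable initializers in place, so the Scaffold's init op
    # restores values from the checkpoint instead of running random inits.
    tf.train.init_from_checkpoint("/path/to/ckpt", {"/": "/"})  # "/" maps all variables
    return tf.train.Scaffold()
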
Example #7
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
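
Examples #5 and #7 expect to run under a TPU Estimator when use_tpu is set; params["context"] in particular is populated by the TPU Estimator machinery, not by user code. A minimal wiring sketch, assuming the TF1.x tf.contrib.tpu API with illustrative values (tpu_master and model_fn are placeholders, e.g. the partial built in the sketch after Example #5):

import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    master=tpu_master,  # illustrative; e.g. resolved via a TPUClusterResolver
    model_dir="/tmp/model",
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=64)  # required when use_tpu=True
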