Example #1
0
    def __init__(self, cpu_model, tpu_name_or_address, strategy):
        """Wrap `cpu_model` as a TPU-backed Keras model.

        Mirrors the CPU model's inputs/outputs/name, opens a session against
        the resolved TPU master, initializes the TPU system, and — if the CPU
        model is already compiled — compiles this model with the same settings.

        Args:
          cpu_model: Keras model whose inputs, outputs, name, and (optional)
            compile configuration are mirrored.
          tpu_name_or_address: Name or grpc address of the Cloud TPU, handed
            to `TPUClusterResolver`.
          strategy: Distribution strategy object, stored on the instance.
        """
        # Deliberately skip the immediate parent's __init__ and initialize as
        # a plain functional Model over the CPU model's graph endpoints.
        super(models.Model, self).__init__(  # pylint: disable=bad-super-call
            inputs=cpu_model.inputs,
            outputs=cpu_model.outputs,
            name=cpu_model.name,
        )

        # Create a mapping from numpy arrays to infeed managers.
        # Note: uses a list of tuples instead of a map because numpy arrays are
        # not hashable.
        self._numpy_to_infeed_manager_list = []

        # Keras execution functions; built lazily elsewhere.
        self.predict_function = None
        self.test_function = None
        self.train_function = None
        self._strategy = strategy

        self._tpu_name_or_address = tpu_name_or_address
        self._cpu_model = cpu_model
        # TPU-side model state is created on demand, not here.
        self._tpu_model = None
        self._tpu_weights_initialized = False
        # Dedicated graph so TPU ops stay isolated from the default graph.
        self._graph = ops.Graph()

        self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
            tpu_name_or_address)
        master = self._cluster_resolver.master()
        cluster_spec = self._cluster_resolver.cluster_spec()
        # isolate_session_state keeps this session's variables private.
        self._session = tf_session.Session(
            graph=self._graph,
            target=master,
            config=config_pb2.ConfigProto(isolate_session_state=True))

        # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env.
        if cluster_spec:
            self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        # Bring up the TPU system before any TPU computation is attempted.
        with self._graph.as_default():
            self._session.run(tpu.initialize_system())

        # If the input CPU model has already been compiled, compile our TPU model
        # immediately.
        if self._cpu_model.optimizer:
            self.compile(
                self._cpu_model.optimizer,
                self._cpu_model.loss,
                self._cpu_model.metrics,
                self._cpu_model.loss_weights,
                self._cpu_model.sample_weight_mode,
                self._cpu_model.weighted_metrics,
                self._cpu_model.target_tensors,
            )
Example #2
0
def TPUDistributionStrategy(tpu_cluster_resolver=None):  # pylint: disable=invalid-name
    """Build a `TPUDistributionStrategy`, bridging old and new TPUStrategy APIs.

    Args:
      tpu_cluster_resolver: Optional resolver; a default `TPUClusterResolver('')`
        is created when omitted.

    Returns:
      A `tpu_strategy.TPUStrategy` instance.
    """
    from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
    # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
    resolver = tpu_cluster_resolver
    if resolver is None:
        resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')

    # Distinguish API versions by the constructor's positional-arg count.
    init_args = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)[0]
    if len(init_args) != 3:
        # Old API: fixed core count; attach the resolver after construction.
        logging.info('Detected old TPUStrategy API.')
        strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
        strategy._tpu_cluster_resolver = resolver  # pylint: disable=protected-access
        return strategy

    logging.info('Detected new TPUStrategy API.')
    return tpu_strategy.TPUStrategy(resolver, steps_per_run=1)
Example #3
0
    def __init__(self, cpu_model, tpu_name_or_address, strategy):
        """Create a TPU-backed mirror of `cpu_model`.

        Builds a functional Model over the CPU model's inputs/outputs, opens a
        session on the resolved TPU master, initializes the TPU system, and
        copies the CPU model's compile settings if it is already compiled.

        Args:
          cpu_model: Keras model to mirror.
          tpu_name_or_address: Name or grpc address of the Cloud TPU, passed
            to `TPUClusterResolver`.
          strategy: Distribution strategy object, stored on the instance.
        """
        # Skip the immediate parent's __init__ on purpose and initialize as a
        # plain functional Model sharing the CPU model's graph endpoints.
        super(models.Model, self).__init__(  # pylint: disable=bad-super-call
            inputs=cpu_model.inputs,
            outputs=cpu_model.outputs,
            name=cpu_model.name,
        )

        self._strategy = strategy
        self._tpu_name_or_address = tpu_name_or_address
        self._cpu_model = cpu_model

        # TPU-side state and Keras execution functions are built lazily.
        self._tpu_model = None
        self._tpu_weights_initialized = False
        self.predict_function = None
        self.test_function = None
        self.train_function = None

        # Dedicated graph plus a session targeting the TPU master.
        self._graph = ops.Graph()
        resolver = tpu_cluster_resolver.TPUClusterResolver(
            tpu_name_or_address)
        spec = resolver.cluster_spec()
        self._session = tf_session.Session(
            graph=self._graph,
            target=resolver.master(),
            config=config_pb2.ConfigProto(isolate_session_state=True))

        if spec:
            self._session.cluster_def.CopyFrom(spec.as_cluster_def())

        # Bring up the TPU system before any TPU computation runs.
        with self._graph.as_default():
            self._session.run(tpu.initialize_system())

        # If the input CPU model has already been compiled, compile our TPU
        # model immediately with the same configuration.
        if self._cpu_model.optimizer:
            self.compile(
                self._cpu_model.optimizer,
                self._cpu_model.loss,
                self._cpu_model.metrics,
                self._cpu_model.loss_weights,
                self._cpu_model.sample_weight_mode,
                self._cpu_model.weighted_metrics,
                self._cpu_model.target_tensors,
            )
def setup_tpu_session(tpu_name_or_address):
    """Initialize and return a Keras/TF session connected to the TPU master.

    Args:
      tpu_name_or_address: A string that is either the name of the Cloud TPU,
        the grpc address of the Cloud TPU, or (Googlers only) the BNS name of
        the Cloud TPU. If None, the TPUClusterResolver will examine the
        environment to determine a potential Cloud TPU to use.

    Returns:
      A `tf.Session` that has been registered as the Keras backend session.
    """
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu_name_or_address)
    spec = resolver.cluster_spec()
    session = tf_session.Session(
        target=resolver.master(),
        config=config_pb2.ConfigProto(isolate_session_state=True))
    if spec:
        session.cluster_def.CopyFrom(spec.as_cluster_def())
    # Register with Keras, then bring up the TPU system on that session.
    K.set_session(session)
    K.get_session().run(tpu.initialize_system())
    return session
Example #5
0
def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None):  # pylint: disable=invalid-name
    """Build a `TPUDistributionStrategy`, bridging old and new TPUStrategy APIs.

    Args:
      tpu_cluster_resolver: Optional resolver; a default `TPUClusterResolver('')`
        is created when omitted.
      num_cores: Optional core count, forwarded to the new-style constructor.

    Returns:
      A `tpu_strategy.TPUStrategy` instance.
    """
    from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
    # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
    # Kept for (a) backwards compatibility with open sourced releases of
    # TensorFlow and (b) to work around a circular dependency between
    # keras_support and tpu_strategy. Once a final version ships and the old
    # API is dropped, this shim is deleted. (See bug above for details.)
    resolver = tpu_cluster_resolver
    if resolver is None:
        resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')

    # Distinguish API versions by the constructor's positional-arg count.
    init_args = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)[0]
    if len(init_args) == 4:
        logging.info('Detected new TPUStrategy API.')
        return tpu_strategy.TPUStrategy(resolver,
                                        steps_per_run=1,
                                        num_cores=num_cores)

    # Old API: fixed core count; attach the resolver after construction.
    logging.info('Detected old TPUStrategy API.')
    strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
    strategy._tpu_cluster_resolver = resolver  # pylint: disable=protected-access
    return strategy