def testDistributedModelFit(self, strategy):
        """Fit one Keras model that chains preprocessing and training.

        Both sub-models are built under ``strategy.scope()`` and stitched
        into a single functional model, which is compiled and then fit on a
        ``DatasetCreator`` whose ``dataset_fn`` shards, batches (per replica),
        repeats and prefetches the data.

        Args:
          strategy: The ``tf.distribute`` strategy under test.
        """
        # DatasetCreator with a ParameterServerStrategy is only exercised
        # when TF2 behavior is enabled; skip otherwise.
        ps_strategy_without_tf2 = (
            not tf.__internal__.tf2.enabled()
            and isinstance(strategy,
                           tf.distribute.experimental.ParameterServerStrategy))
        if ps_strategy_without_tf2:
            self.skipTest(
                "Parameter Server strategy with dataset creator need to be run "
                "when eager execution is enabled.")

        with strategy.scope():
            preprocessor = utils.make_preprocessing_model(self.get_temp_dir())
            trainer = utils.make_training_model()
            # A single functional model lets one `fit` call run the
            # preprocessing model followed by the training model.
            model_inputs = preprocessor.inputs
            merged_model = tf.keras.Model(
                model_inputs, trainer(preprocessor(model_inputs)))
            merged_model.compile(optimizer="sgd", loss="binary_crossentropy")

        def dataset_fn(input_context):
            # Each input pipeline reads only its own shard of the data.
            shard = utils.make_dataset().shard(
                input_context.num_input_pipelines,
                input_context.input_pipeline_id)
            per_replica_batch = input_context.get_per_replica_batch_size(
                global_batch_size=utils.BATCH_SIZE)
            return shard.batch(per_replica_batch).repeat().prefetch(2)

        # The dataset repeats indefinitely, so `steps_per_epoch` bounds each
        # epoch.
        merged_model.fit(
            tf.keras.utils.experimental.DatasetCreator(dataset_fn),
            epochs=2,
            steps_per_epoch=utils.STEPS)
Esempio n. 2
0
    def testDistributedModelFit(self, strategy):
        """Fit a distributed model on a dataset preprocessed via ``map``.

        The preprocessing and training models are built under
        ``strategy.scope()``; only the training model is compiled and fit.
        Every batch is passed through the preprocessing model inside the
        ``tf.data`` pipeline rather than inside the trained model.

        Args:
          strategy: The ``tf.distribute`` strategy under test.
        """
        with strategy.scope():
            preprocessor = utils.make_preprocessing_model(self.get_temp_dir())
            model = utils.make_training_model()
            model.compile(optimizer="sgd", loss="binary_crossentropy")

        def preprocess(features, labels):
            # Only the first element of each pair is transformed; the second
            # is passed through unchanged.
            return preprocessor(features), labels

        batched = utils.make_dataset().batch(utils.BATCH_SIZE)
        model.fit(batched.map(preprocess), epochs=2)
Esempio n. 3
0
  def testDistributedModelFit(self, strategy):
    """Fit a model on a DatasetCreator pipeline that preprocesses in ``map``.

    The preprocessing and training models are built under
    ``strategy.scope()``; only the training model is compiled and fit. The
    ``dataset_fn`` shards the data per input pipeline, batches it with the
    per-replica batch size, applies the preprocessing model to the first
    element of each pair, then repeats and prefetches.

    Args:
      strategy: The ``tf.distribute`` strategy under test.
    """
    with strategy.scope():
      preprocessing_model = utils.make_preprocessing_model(self.get_temp_dir())
      training_model = utils.make_training_model()
      training_model.compile(optimizer="sgd", loss="binary_crossentropy")

    def dataset_fn(input_context):
      dataset = utils.make_dataset()
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      batch_size = input_context.get_per_replica_batch_size(
          global_batch_size=utils.BATCH_SIZE)
      dataset = dataset.batch(batch_size)
      # Apply preprocessing before `prefetch` so the prefetch buffer holds
      # already-preprocessed batches. The previous order prefetched raw
      # batches and mapped afterwards, leaving preprocessing outside the
      # prefetch overlap. tf.data guidance is to keep `prefetch` last.
      dataset = dataset.map(lambda x, y: (preprocessing_model(x), y))
      return dataset.repeat().prefetch(2)

    dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
    # The dataset repeats indefinitely, so `steps_per_epoch` bounds each epoch.
    training_model.fit(dataset_creator, epochs=2, steps_per_epoch=utils.STEPS)