예제 #1
0
    def test_job_id(self, mock_serialize_assets, mock_submit_job):
        """Verify that an explicit ``job_id`` shows up in the submitted job body.

        Args:
            mock_serialize_assets: Not inspected by this test; presumably
                injected by a patch decorator outside this excerpt.
            mock_submit_job: Mock whose recorded call arguments are inspected.
        """
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(self._model,
                                 x=self._dataset,
                                 validation_data=self._dataset,
                                 remote_dir=self._remote_dir,
                                 job_spec=self._job_spec,
                                 batch_size=1,
                                 epochs=2,
                                 verbose=3)
            return

        test_job_id = 'test_job_id'
        client.cloud_fit(self._model,
                         x=self._dataset,
                         validation_data=self._dataset,
                         remote_dir=self._remote_dir,
                         job_spec=self._job_spec,
                         job_id=test_job_id,
                         batch_size=1,
                         epochs=2,
                         verbose=3)

        # call_args unpacks to (positional args, keyword args); `body` is the
        # first of the two positional arguments of the recorded call.
        positional_args, _ = mock_submit_job.call_args
        body, _ = positional_args
        # assertDictContainsSubset is deprecated since Python 3.2 and was
        # removed from unittest in 3.12, so check the key explicitly.
        self.assertIn('job_id', body)
        self.assertEqual(body['job_id'], test_job_id)
예제 #2
0
    def test_serialize_assets(self):
        """Serialize the model and fit kwargs, then check the written assets.

        After calling ``client._serialize_assets`` the test expects non-empty
        ``training_assets`` and ``model`` folders under the remote dir, and
        expects the pickled callbacks loaded back from ``training_assets`` to
        start with a ``TensorBoard`` callback.
        """
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(self._model,
                                 x=self._dataset,
                                 validation_data=self._dataset,
                                 remote_dir=self._remote_dir,
                                 job_spec=self._job_spec,
                                 batch_size=1,
                                 epochs=2,
                                 verbose=3)
            return
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=self._remote_dir)
        # Build a new dict rather than inserting 'callbacks' into the shared
        # fixture, so self._scalar_fit_kwargs is left unmodified.
        args = dict(self._scalar_fit_kwargs, callbacks=[tensorboard_callback])

        client._serialize_assets(self._remote_dir, self._model, **args)

        training_assets_dir = os.path.join(self._remote_dir, 'training_assets')
        model_dir = os.path.join(self._remote_dir, 'model')
        # Both output folders must contain at least one entry.
        self.assertGreaterEqual(len(tf.io.gfile.listdir(training_assets_dir)),
                                1)
        self.assertGreaterEqual(len(tf.io.gfile.listdir(model_dir)), 1)

        training_assets_graph = tf.saved_model.load(training_assets_dir)

        # The callbacks are read back as a pickled payload written by the
        # call above in this same test.
        pickled_callbacks = tfds.as_numpy(training_assets_graph.callbacks_fn())
        unpickled_callbacks = cloudpickle.loads(pickled_callbacks)
        self.assertIsInstance(unpickled_callbacks[0],
                              tf.keras.callbacks.TensorBoard)
    def test_in_memory_data(self):
        """Run ``cloud_fit`` on in-memory NumPy data and wait for the job.

        Submits a job with random ``x``/``y`` arrays, then polls the job
        through the ``ml`` v1 discovery client until it reaches a terminal
        state and asserts that the state is ``SUCCEEDED``.
        """
        # Create a folder under remote dir for this test's data
        tmp_folder = str(uuid.uuid4())
        remote_dir = os.path.join(self.remote_dir, tmp_folder)

        # Keep track of test folders created for final clean up
        self.test_folders.append(remote_dir)

        x = np.random.random((2, 3))
        y = np.random.randint(0, 2, (2, 2))

        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(self.model(),
                                 x=x,
                                 y=y,
                                 remote_dir=remote_dir,
                                 region=self.region,
                                 project_id=self.project_id,
                                 image_uri=self.image_uri,
                                 epochs=2)
            return

        job_id = client.cloud_fit(self.model(),
                                  x=x,
                                  y=y,
                                  remote_dir=remote_dir,
                                  region=self.region,
                                  project_id=self.project_id,
                                  image_uri=self.image_uri,
                                  job_id='cloud_fit_e2e_test_{}_{}'.format(
                                      BUILD_ID.replace('-', '_'),
                                      'test_in_memory_data'),
                                  epochs=2)

        # Build the request used to query the submitted training job.
        job_name = 'projects/{}/jobs/{}'.format(self.project_id, job_id)
        api_client = discovery.build('ml', 'v1')
        request = api_client.projects().jobs().get(name=job_name)

        # Wait for the training job to reach a terminal state. CANCELLED is
        # treated as terminal too, so a cancelled job fails the assertion
        # below instead of keeping this loop polling forever.
        terminal_states = ('SUCCEEDED', 'FAILED', 'CANCELLED')
        response = request.execute()
        while response['state'] not in terminal_states:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            response = request.execute()
        self.assertEqual(response['state'], 'SUCCEEDED')
예제 #4
0
  def test_fit_kwargs(self, mock_submit_job):
    """Check the fit kwargs restored from the serialized training assets.

    Calls ``client.cloud_fit``, reads the remote dir back out of the job body
    recorded by ``mock_submit_job``, loads ``training_assets`` from it and
    compares the restored fit kwargs with the scalar values passed in.

    Args:
      mock_submit_job: Mock whose recorded call arguments are inspected.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return
    job_id = client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._image_uri,
        batch_size=1,
        epochs=2,
        verbose=3)

    # call_args unpacks to (positional args, keyword args); `body` is the
    # first of the two positional arguments of the recorded call.
    positional_args, _ = mock_submit_job.call_args
    body, _ = positional_args
    self.assertEqual(body['job_id'], job_id)
    # The remote dir is read from the second entry of the job's args list.
    remote_dir = body['trainingInput']['args'][1]

    training_assets_graph = tf.saved_model.load(
        os.path.join(remote_dir, 'training_assets'))
    restored_kwargs = tfds.as_numpy(training_assets_graph.fit_kwargs_fn())

    expected_kwargs = {
        'batch_size': 1,
        'epochs': 2,
        'verbose': 3,
    }
    # assertDictContainsSubset is deprecated since Python 3.2 and was removed
    # from unittest in 3.12; this loop performs the same check explicitly.
    # NOTE(review): as in the original call, this verifies that every restored
    # kwarg matches the expected scalars, not that every expected kwarg was
    # restored -- confirm whether the reverse direction is what is intended.
    for key, value in restored_kwargs.items():
      self.assertIn(key, expected_kwargs)
      self.assertEqual(value, expected_kwargs[key])
예제 #5
0
    def test_distribution_strategy(self, mock_serialize_assets,
                                   mock_submit_job):
        """Check the ``--distribution_strategy`` argument of submitted jobs.

        Covers three cases: no strategy given, an explicit
        ``MIRRORED_STRATEGY_NAME``, and an unknown strategy name, which is
        expected to raise ``ValueError``.

        Args:
            mock_serialize_assets: Not inspected by this test; presumably
                injected by a patch decorator outside this excerpt.
            mock_submit_job: Mock whose recorded call arguments are inspected.
        """
        # TF 1.x is not supported
        if utils.is_tf_v1():
            with self.assertRaises(RuntimeError):
                client.cloud_fit(self._model,
                                 x=self._dataset,
                                 remote_dir=self._remote_dir)
            return

        # Case 1: no strategy passed; the job args are expected to carry
        # MULTI_WORKER_MIRRORED_STRATEGY_NAME.
        client.cloud_fit(self._model,
                         x=self._dataset,
                         remote_dir=self._remote_dir)

        # call_args unpacks to (positional args, keyword args); `body` is the
        # first of the two positional arguments of the recorded call.
        positional_args, _ = mock_submit_job.call_args
        body, _ = positional_args
        training_input = body['trainingInput']
        # assertDictContainsSubset is deprecated since Python 3.2 and was
        # removed from unittest in 3.12, so check the key explicitly.
        self.assertIn('args', training_input)
        self.assertEqual(training_input['args'], [
            '--remote_dir', self._remote_dir,
            '--distribution_strategy',
            MULTI_WORKER_MIRRORED_STRATEGY_NAME
        ])

        # Case 2: an explicitly requested strategy is forwarded unchanged.
        client.cloud_fit(self._model,
                         x=self._dataset,
                         remote_dir=self._remote_dir,
                         distribution_strategy=MIRRORED_STRATEGY_NAME,
                         job_spec=self._job_spec)

        positional_args, _ = mock_submit_job.call_args
        body, _ = positional_args
        training_input = body['trainingInput']
        self.assertIn('args', training_input)
        self.assertEqual(training_input['args'], [
            '--remote_dir', self._remote_dir,
            '--distribution_strategy', MIRRORED_STRATEGY_NAME
        ])

        # Case 3: an unknown strategy name is rejected.
        with self.assertRaises(ValueError):
            client.cloud_fit(self._model,
                             x=self._dataset,
                             remote_dir=self._remote_dir,
                             distribution_strategy='not_implemented_strategy',
                             job_spec=self._job_spec)
예제 #6
0
  def test_custom_job_spec(self, mock_submit_job):
    """Check the training input submitted when a custom job spec is passed.

    The ``trainingInput`` of the recorded job body is expected to contain a
    ``masterConfig`` equal to ``{'imageUri': self._image_uri}`` and an ``args``
    list with the remote dir and ``MULTI_WORKER_MIRRORED_STRATEGY_NAME``.

    Args:
      mock_submit_job: Mock whose recorded call arguments are inspected.
    """
    # TF 1.x is not supported
    if utils.is_tf_v1():
      with self.assertRaises(RuntimeError):
        client.cloud_fit(
            self._model,
            x=self._dataset,
            validation_data=self._dataset,
            remote_dir=self._remote_dir,
            job_spec=self._job_spec,
            batch_size=1,
            epochs=2,
            verbose=3)
      return

    client.cloud_fit(
        self._model,
        x=self._dataset,
        validation_data=self._dataset,
        remote_dir=self._remote_dir,
        job_spec=self._job_spec,
        batch_size=1,
        epochs=2,
        verbose=3)

    # call_args unpacks to (positional args, keyword args); `body` is the
    # first of the two positional arguments of the recorded call.
    positional_args, _ = mock_submit_job.call_args
    body, _ = positional_args
    training_input = body['trainingInput']

    expected_entries = {
        'masterConfig': {
            'imageUri': self._image_uri,
        },
        'args': [
            '--remote_dir', self._remote_dir, '--distribution_strategy',
            MULTI_WORKER_MIRRORED_STRATEGY_NAME
        ],
    }
    # assertDictContainsSubset is deprecated since Python 3.2 and was removed
    # from unittest in 3.12; check each expected entry explicitly instead.
    for key, expected_value in expected_entries.items():
      self.assertIn(key, training_input)
      self.assertEqual(training_input[key], expected_value)