Exemplo n.º 1
0
 def test_fetch_warm_starting_model(self):
     """Checks that _fetch_latest_model returns the URI of the highest-id Model.

     Three Model artifacts (ids 1..3, uris 'uri-<id>') are served by a mocked
     metadata handler's get_artifacts_by_type; the test expects 'uri-3'.
     """
     mock_metadata = tf.test.mock.Mock()
     artifacts = []
     # Deliberately out of order: with a descending list the expected artifact
     # would also be the first one returned, so "take the first artifact"
     # would pass as well as "take the latest". Here id 3 is neither first
     # nor last.
     for aid in [1, 3, 2]:
         model = standard_artifacts.Model()
         model.id = aid
         model.uri = 'uri-%d' % aid
         artifacts.append(model.artifact)
     mock_metadata.get_artifacts_by_type.return_value = artifacts
     trainer_driver = driver.Driver(mock_metadata)
     result = trainer_driver._fetch_latest_model()
     self.assertEqual('uri-3', result)
Exemplo n.º 2
0
 def test_fetch_warm_starting_model(self):
     """Checks that _fetch_latest_model returns the URI of the highest-span model.

     Three 'ModelExportPath' artifacts (spans 1..3, uris 'uri-<span>') are
     served by a mocked metadata handler's get_artifacts_by_type; the test
     expects 'uri-3'.
     """
     mock_metadata = tf.test.mock.Mock()
     artifacts = []
     # Deliberately out of order: with a descending list the expected artifact
     # would also be the first one returned, so "take the first artifact"
     # would pass as well as "take the latest". Here span 3 is neither first
     # nor last.
     for span in [1, 3, 2]:
         model = types.Artifact(type_name='ModelExportPath')
         model.span = span
         model.uri = 'uri-%d' % span
         artifacts.append(model.artifact)
     mock_metadata.get_artifacts_by_type.return_value = artifacts
     trainer_driver = driver.Driver(mock_metadata)
     result = trainer_driver._fetch_latest_model()
     self.assertEqual('uri-3', result)
Exemplo n.º 3
0
 def test_fetch_warm_starting_model(self):
     """Checks that _fetch_latest_model returns the URI of the highest-span model.

     Three 'ModelExportPath' artifacts (spans 1..3, uris 'uri-<span>') are
     served by a mocked metadata handler's get_all_artifacts; the test
     expects 'uri-3'.
     """
     mock_metadata = tf.test.mock.Mock()
     artifacts = []
     # Deliberately out of order: with a descending list the expected artifact
     # would also be the first one returned, so "take the first artifact"
     # would pass as well as "take the latest". Here span 3 is neither first
     # nor last.
     for span in [1, 3, 2]:
         model = types.TfxType(type_name='ModelExportPath')
         model.span = span
         model.uri = 'uri-%d' % span
         artifacts.append(model.artifact)
     mock_metadata.get_all_artifacts.return_value = artifacts
     # The driver takes a log root. Place it under the test's undeclared
     # outputs dir when set, otherwise under the TestCase temp dir, in a
     # per-test subdirectory.
     output_data_dir = os.path.join(
         os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
         self._testMethodName)
     log_root = os.path.join(output_data_dir, 'log_root')
     trainer_driver = driver.Driver(log_root, mock_metadata)
     result = trainer_driver._fetch_latest_model()
     self.assertEqual('uri-3', result)