def test_empty_input(self):
    """Export a model, reload it via hub.load, and check its output shape.

    Runs ``export.train_and_export`` with ``epoch=1`` on ``self.mock_dataset``,
    writing to ``<temp_dir>/model/1``. The exported model is reloaded and
    called on a single all-zeros uint8 image of shape (1, 28, 28, 1). The
    batch is passed as a NumPy array, not a tensor. Despite the test's name,
    the input is a zero-valued batch, not an empty one. The output shape
    must be [1, 10].
    """
    export_dir = "%s/model/1" % self.get_temp_dir()
    export.train_and_export(
        epoch=1,
        dataset=self.mock_dataset,
        export_path=export_dir,
    )
    loaded = hub.load(export_dir)
    blank_batch = tf.zeros([1, 28, 28, 1], dtype=tf.uint8).numpy()
    prediction = loaded(blank_batch)
    self.assertEqual(prediction.shape, [1, 10])
def test_model_exporting(self):
    """Check that train_and_export writes output at the requested path.

    Runs ``export.train_and_export`` with ``epoch=1`` on ``self.mock_dataset``.
    It then asserts that the ``export_path`` it was given (``<temp_dir>/model/1``)
    is an existing, non-empty directory.
    """
    export_dir = "%s/model/1" % self.get_temp_dir()
    export.train_and_export(
        epoch=1,
        dataset=self.mock_dataset,
        export_path=export_dir,
    )
    # Checking only the temp-dir root would also pass if something were
    # written elsewhere under it (e.g. just an intermediate "model/" dir).
    # Assert on the export directory itself instead.
    self.assertTrue(os.path.isdir(export_dir))
    self.assertTrue(os.listdir(export_dir))