def test_recommendation_demo(self):
        """Run the recommendation demo end to end and check the TFLite export.

        Downloads the data with the download/extract step patched by
        `_rt.patch_download_and_extract_data`, and checks that the returned
        directory equals `self.dataset_dir`. Then runs the demo for a single
        epoch with the data loader patched. Finally asserts that
        `<test_tempdir>/export/model.tflite` exists and is non-empty.
        """
        with _rt.patch_download_and_extract_data(self.dataset_dir):
            downloaded_dir = recommendation_demo.download_data(
                self.download_dir)
        self.assertEqual(downloaded_dir, self.dataset_dir)

        out_dir = os.path.join(self.test_tempdir, 'export')
        model_path = os.path.join(out_dir, 'model.tflite')
        with patch_data_loader():
            recommendation_demo.run(downloaded_dir, out_dir, epochs=1)

        # The exported model must be present and contain at least one byte.
        self.assertTrue(tf.io.gfile.exists(model_path))
        self.assertGreater(os.path.getsize(model_path), 0)
# Example #2
def setup_testdata(instance):
    """Set up MovieLens test data under download_dir and unzip it to dataset_dir.

    Reuses `instance.test_tempdir` if present, otherwise creates a fresh
    temporary directory. Creates `<test_tempdir>/download` and stores it as
    `instance.download_dir`. Then populates `instance.dataset_dir`:

    * If `test_util.get_test_data_path('recommendation_movielens')` locates
      local test data, `ml-1m.zip` from that location is copied into
      `download_dir` and extracted there. `dataset_dir` becomes
      `<download_dir>/ml-1m`.
    * If that lookup raises `ValueError` (presumably meaning no local test
      data is available; confirm against `test_util`), the data is obtained
      with `recommendation_demo.download_data(download_dir)` instead.

    Args:
      instance: Object (e.g. a test case) on which the `test_tempdir`,
        `download_dir` and `dataset_dir` attributes are set.
    """
    if not hasattr(instance, 'test_tempdir'):
        instance.test_tempdir = tempfile.mkdtemp()
    instance.download_dir = os.path.join(instance.test_tempdir, 'download')
    os.makedirs(instance.download_dir, exist_ok=True)

    # Use the existing local copy of the data if there is one; otherwise
    # download it. Only the lookup is guarded. A ValueError raised while
    # copying or extracting a found archive must surface instead of being
    # mistaken for "no local data" and silently triggering a download.
    try:
        path = test_util.get_test_data_path('recommendation_movielens')
    except ValueError:
        instance.dataset_dir = recommendation_demo.download_data(
            instance.download_dir)
    else:
        zip_file = os.path.join(path, 'ml-1m.zip')
        shutil.copy(zip_file, instance.download_dir)
        with zipfile.ZipFile(zip_file, 'r') as zfile:
            # The archive extracts into an 'ml-1m' subdirectory.
            zfile.extractall(instance.download_dir)
        instance.dataset_dir = os.path.join(instance.download_dir, 'ml-1m')