def test_download_file_gdrive():
    """Download one plain file with extraction disabled, then delete it.

    The ``url`` argument is a bare id rather than a full URL (presumably a
    Google Drive file id, going by the test name). The test passes when
    ``download_file`` returns a non-empty collection; every returned path is
    removed afterwards so the cache directory is left clean.
    """
    downloaded = datautils.download_file(
        name='requirements.txt',
        url='1_K6gqtZrvGWsINfkY0J6Evbd3bbWLg_0',
        cache_dir=BASE_PATH,
        extract=False,
    )
    print(downloaded)

    assert len(downloaded) > 0

    # Cleanup: each returned entry is removed as a regular file.
    for path in downloaded:
        os.remove(path)
def test_download_archive_gdrive():
    """Download an archive by bare id and clean up the directories it produced.

    ``extract`` is not passed, so the call relies on whatever default
    ``download_file`` defines for it. The test passes when at least one path
    is returned; afterwards the parent directory of every returned path is
    deleted.
    """
    extracted = datautils.download_file(
        name='test.tar.gz',
        url='1Sr1fm8PaaKQuvVpl34t6xyWty61o2psH',
        cache_dir=BASE_PATH,
    )
    print(extracted)

    assert len(extracted) > 0

    for member in extracted:
        folder = member.parent
        # Several members may share a parent that an earlier iteration
        # already removed, hence the existence check before each rmtree.
        if folder.exists():
            shutil.rmtree(folder)
def test_download_file_github():
    """Download a GitHub release asset over HTTPS with extraction disabled.

    The test passes when ``download_file`` returns a non-empty collection.
    Cleanup removes the directory containing the first returned path.
    """
    source_url = (
        'https://github.com/keras-team/keras-applications/releases/download/'
        'densenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5'
    )
    downloaded = datautils.download_file(
        name='densenet121_weights_tf_dim_ordering_tf_kernels.h5',
        url=source_url,
        cache_dir=BASE_PATH,
        extract=False,
    )
    print(downloaded)

    assert len(downloaded) > 0

    first_path = downloaded[0]
    shutil.rmtree(first_path.parent)
def test_download_archive():
    """Download a zip archive over HTTPS with extraction explicitly enabled.

    The test passes when ``download_file`` returns a non-empty collection.
    Cleanup removes the directory holding the first returned path and then
    that directory's own parent.
    """
    extracted = datautils.download_file(
        name='mnist.zip',
        url='https://github.com/mlampros/DataSets/raw/master/mnist.zip',
        cache_dir=BASE_PATH,
        extract=True,
    )
    print(extracted)

    assert len(extracted) > 0

    # Two-level cleanup, innermost directory first.
    inner_dir = extracted[0].parent
    shutil.rmtree(inner_dir)
    shutil.rmtree(inner_dir.parent)
def test_download_file_extract():
    """Expect an ``AssertionError`` when a plain text file is requested
    without passing ``extract``.

    NOTE(review): this presumably exercises the default ``extract`` behaviour
    of ``download_file`` rejecting a non-archive file; confirm against the
    helper's implementation.
    """
    request = dict(
        name='requirements.txt',
        url='1_K6gqtZrvGWsINfkY0J6Evbd3bbWLg_0',
        cache_dir=BASE_PATH,
    )
    with pytest.raises(AssertionError):
        datautils.download_file(**request)