Пример #1
0
def hymenoptera_data_download(path: str, predict_size: int = 10):
    """Download the hymenoptera dataset and populate its ``predict`` folder.

    The archive is fetched into ``path`` through ``download_data``. When
    ``hymenoptera_data/predict`` is missing or empty, ``predict_size`` images
    are drawn at random (without replacement) from the validation images and
    copied into it. A predict folder that already contains entries is left
    untouched.

    Args:
        path: Directory that receives the downloaded dataset.
        predict_size: Number of validation images to copy. Must be strictly
            smaller than the number of validation images found.

    Raises:
        AssertionError: If ``predict_size`` is not smaller than the number of
            validation images found under ``hymenoptera_data/val``.
    """
    download_data("https://download.pytorch.org/tutorial/hymenoptera_data.zip", path)
    predict_folder = os.path.join(path, "hymenoptera_data/predict")
    os.makedirs(predict_folder, exist_ok=True)
    if os.listdir(predict_folder):
        # Already populated by an earlier call; do not resample.
        return
    validation_image_paths = glob(os.path.join(path, "hymenoptera_data/val/*/*"))
    assert predict_size < len(validation_image_paths), (
        f"predict_size ({predict_size}) must be smaller than the number of "
        f"validation images ({len(validation_image_paths)})"
    )
    indices = np.random.choice(len(validation_image_paths), predict_size, replace=False)
    for index in indices:
        src = validation_image_paths[index]
        # basename honours the platform's path separator; splitting on '/'
        # returns a partial path when glob yields backslash-separated results.
        dst = os.path.join(predict_folder, os.path.basename(src))
        shutil.copy(src, dst)
Пример #2
0
def titanic_data_download(path: str, predict_size: float = 0.1) -> None:
    """Download the Titanic CSV into ``path`` and carve out a prediction split.

    ``titanic.csv`` is fetched through ``download_data``. Unless ``path``
    already contains exactly ``titanic.csv`` and ``predict.csv``, the data is
    split with ``train_test_split``: the training part overwrites
    ``titanic.csv`` and the remainder, with the ``Survived`` column dropped,
    is written to ``predict.csv``.

    Args:
        path: Directory holding the CSV files; created when it does not exist.
        predict_size: Fraction of rows moved to ``predict.csv``; must lie
            strictly between 0 and 1.
    """
    if not os.path.exists(path):
        os.makedirs(path)

    csv_file = os.path.join(path, "titanic.csv")
    download_data("https://pl-flash-data.s3.amazonaws.com/titanic.csv", csv_file)

    # Nothing to do when the folder holds exactly the two expected files.
    if set(os.listdir(path)) == {"predict.csv", "titanic.csv"}:
        return

    assert 0 < predict_size < 1
    full_frame = pd.read_csv(csv_file)
    train_frame, predict_frame = train_test_split(full_frame, test_size=predict_size)
    train_frame.to_csv(csv_file)
    predict_frame = predict_frame.drop(columns=["Survived"])
    predict_frame.to_csv(os.path.join(path, "predict.csv"))
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
# Fetches the IMDB archive into the local data/ folder.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load the data
# Builds the datamodule from the train/valid/test CSV files. "review" is passed
# as the input field and "sentiment" as the target field (presumably CSV column
# names -- confirm against the data).
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512
)

# 3. Build the model
# The number of output classes is taken from the datamodule.
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
Пример #4
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from flash.core.data import download_data
from flash.vision import ImageEmbedder

# 1. Download the data
# Fetches the hymenoptera archive into the local data/ folder.
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              'data/')

# 2. Create an ImageEmbedder with a SwAV backbone trained on ImageNet.
# Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

# 3. Generate an embedding from an image path.
embeddings = embedder.predict(
    'data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')

# 4. Print the shape of the embeddings
print(embeddings.shape)

# 5. Create a random image tensor
# NOTE(review): shape (1, 3, 32, 32) is presumably (batch, channels, height, width) -- confirm.
random_image = torch.randn(1, 3, 32, 32)
Пример #5
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.core.data import download_data
from flash.vision import ObjectDetectionData, ObjectDetector

# 1. Download the data
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

# 2. Load the Data
# Builds the datamodule from the COCO-format image folder and annotation file.
datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    batch_size=2
)

# 3. Build the model
# The number of classes is taken from the datamodule.
model = ObjectDetector(num_classes=datamodule.num_classes)

# 4. Create the trainer. Train for at most 3 epochs
trainer = flash.Trainer(max_epochs=3)

# 5. Finetune the model
Пример #6
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer

from flash.core.data import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
# Fetches the xsum archive into the local data/ folder.
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the model from a checkpoint
# The checkpoint is referenced by URL rather than by a local file path.
model = SummarizationTask.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

# 2a. Summarize an article!
predictions = model.predict([
    """
    Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
    people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
    They came to Brixton to see work which has started to revitalise the borough.
    It was Charles' first visit to the area since 1996, when he was accompanied by the former
    South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
    for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
    ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
Пример #7
0
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytorch_lightning as pl
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from flash import ClassificationTask
from flash.core.data import download_data

# Root folder: the parent of the directory that contains this file.
_PATH_ROOT = os.path.dirname(os.path.dirname(__file__))

# 1. Download the data
# The MNIST archive is stored in the "data" folder under _PATH_ROOT.
download_data("https://www.di.ens.fr/~lelarge/MNIST.tar.gz", os.path.join(_PATH_ROOT, 'data'))

# 2. Load a basic backbone
# A small fully connected network: each input is flattened to 28 * 28 values,
# passed through one hidden layer of 128 units, and mapped to 10 outputs.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    # Normalise over dimension 1 (the 10 outputs) explicitly. Omitting `dim`
    # relies on a deprecated implicit choice and emits a warning on every call.
    nn.Softmax(dim=1),
)

# 3. Load a dataset
# MNIST is loaded from the "data" folder under _PATH_ROOT (download=True is
# passed to torchvision); every sample is converted with ToTensor.
dataset = datasets.MNIST(os.path.join(_PATH_ROOT, 'data'), download=True, transform=transforms.ToTensor())

# 4. Split the data randomly
# Three subsets of 50000 / 5000 / 5000 samples (60000 in total).
train, val, test = random_split(dataset, [50000, 5000, 5000])  # type: ignore
Пример #8
0
def test_download_data(tmpdir):
    """Check that downloading the Titanic archive fills the target folder.

    After ``download_data`` runs, the folder must contain exactly the
    ``titanic.zip`` archive and a ``titanic`` entry next to it.
    """
    target_dir = os.path.join(tmpdir, "data")
    download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", target_dir)
    found_entries = set(os.listdir(target_dir))
    assert found_entries == {'titanic', 'titanic.zip'}
Пример #9
0
                        type=str,
                        default="data/wmt_en_ro/train.csv")
    parser.add_argument('--valid_file',
                        type=str,
                        default="data/wmt_en_ro/valid.csv")
    parser.add_argument('--test_file',
                        type=str,
                        default="data/wmt_en_ro/test.csv")
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--gpus', type=int, default=None)
    args = parser.parse_args()

    # 1. Download the data
    if args.download:
        download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip",
                      "data/")

    # 2. Load the data
    # NOTE(review): the file paths below are hard-coded, so the parsed
    # --valid_file / --test_file command-line values are not used here.
    # Presumably args.valid_file / args.test_file (and the matching train
    # argument) should be passed instead -- confirm against the full parser.
    datamodule = TranslationData.from_files(
        train_file="data/wmt_en_ro/train.csv",
        valid_file="data/wmt_en_ro/valid.csv",
        test_file="data/wmt_en_ro/test.csv",
        input="input",
        target="target",
    )

    # 3. Build the model
    model = TranslationTask(backbone=args.backbone)

    # 4. Create the trainer. Run once on data
    trainer = flash.Trainer(max_epochs=args.max_epochs,