# Example #1
# 0
def main(epochs):
    """Fine-tune a ResNet-34 classifier on the fastai PETS images.

    Registers an experiment ``Task`` (project "examples", task "fastai v2";
    presumably ClearML -- confirm), builds dataloaders from the file names,
    shows a sample batch, fine-tunes for ``epochs`` epochs and then shows
    predictions on validation samples.

    Args:
        epochs: number of epochs passed to ``Learner.fine_tune``.

    NOTE(review): ``label_func`` is not defined in this function; it must be
    provided at module level elsewhere in the file.
    """
    Task.init(project_name="examples", task_name="fastai v2")

    dataset_root = untar_data(URLs.PETS)
    image_dir = dataset_root / "images"
    image_files = get_image_files(image_dir)

    loaders = ImageDataLoaders.from_name_func(
        dataset_root,
        image_files,
        label_func,
        item_tfms=Resize(224),
        num_workers=0,  # no worker subprocesses for data loading
    )
    loaders.show_batch()

    learner = cnn_learner(loaders, resnet34, metrics=error_rate)
    learner.fine_tune(epochs)
    learner.show_results()
# Example #2
# 0
def create_dataloaders() -> DataLoaders:
    """
    Build the cats-vs-dogs dataloaders from the fastai PETS image folder.

    Labels are produced by ``is_cat`` applied to each file name, which is
    expected to be defined at module level. 20% of the images are held out
    for validation, and ``seed=42`` makes the split reproducible. Every item
    goes through a ``Resize(224)`` transform, and batches contain 8 items.

    Returns:
        The ``DataLoaders`` built by ``ImageDataLoaders.from_name_func``.
    """
    image_dir = untar_data(URLs.PETS) / "images"
    image_files = get_image_files(image_dir)
    return ImageDataLoaders.from_name_func(
        image_dir,
        image_files,
        label_func=is_cat,
        valid_pct=0.2,
        seed=42,
        batch_size=8,
        item_tfms=Resize(224),
    )
# Example #3
# 0
from fastai.vision.all import untar_data, URLs, ImageDataLoaders, get_image_files, Resize, error_rate, resnet34, \
    cnn_learner

from labml import lab, experiment
from labml.utils.fastai import LabMLFastAICallback

# Fetch/extract the fastai PETS archive into labml's data directory, then point
# `path` at the extracted `images/` sub-folder.
# NOTE(review): `dest`/`fname` match the older fastai `untar_data` signature;
# newer fastai releases use different parameter names -- confirm against the
# installed version.
path = untar_data(
    URLs.PETS,
    dest=lab.get_data_path(),
    # keep the archive in the same data directory, under fastai's own file name
    fname=lab.get_data_path() / URLs.path(URLs.PETS).name) / 'images'


def is_cat(x):
    """Return True when the file name ``x`` starts with an uppercase letter.

    Used as the label function for the PETS dataset. It relies on the
    (assumed) naming convention that cat-breed image files are capitalised
    while dog-breed files are lowercase.
    Raises IndexError for an empty string.
    """
    leading_char = x[0]
    return leading_char.isupper()


# Dataloaders labelled by `is_cat`; a seeded 20% validation split keeps runs reproducible.
dls = ImageDataLoaders.from_name_func(
    path,
    get_image_files(path),
    label_func=is_cat,
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224),
)

# Build a ResNet-34 learner that reports to labml through its fastai callback ⚡
learn = cnn_learner(
    dls,
    resnet34,
    metrics=error_rate,
    cbs=LabMLFastAICallback(),
)

# Record the 5-epoch fine-tuning run as the 'pets' experiment, logging the learner's configs.
with experiment.record(name='pets', exp_conf=learn.labml_configs()):
    learn.fine_tune(5)
def remove_failed_images(path: Path) -> L:
    """Delete every image under ``path`` that ``verify_images`` reports as unreadable.

    ``V`` is presumably the ``fastai.vision.all`` module (confirm at the import site).

    Returns:
        The ``L`` obtained by mapping ``Path.unlink`` over the failed files,
        i.e. one ``None`` per deleted file.
    """
    candidates = V.get_image_files(path)
    broken = V.verify_images(candidates)
    # Path.unlink removes each file from disk; its None results are what gets returned
    return broken.map(Path.unlink)