示例#1
0
def test_classification(tmpdir):
    """Smoke-test finetuning an ``ImageClassifier`` on two generated images.

    One image is written into each of two sub-folders of ``tmpdir``; the two
    files are then fed to ``ImageClassificationData.from_files`` with targets
    ``0`` and ``1`` and a ``fast_dev_run`` finetune is executed with the
    ``"freeze"`` strategy.
    """
    root = Path(tmpdir)

    # (folder name, image file name) for each of the two classes.
    samples = (("a", "a_1.png"), ("b", "b_1.png"))

    for folder, _ in samples:
        (root / folder).mkdir()

    image_paths = [str(root / folder / file_name) for folder, file_name in samples]

    for path in image_paths:
        _rand_image().save(path)

    data = ImageClassificationData.from_files(
        train_files=image_paths,
        train_targets=[0, 1],
        num_workers=0,
        batch_size=2,
        image_size=(64, 64),
    )
    model = ImageClassifier(num_classes=2, backbone="resnet18")

    trainer = Trainer(default_root_dir=root, fast_dev_run=True)
    trainer.finetune(model, datamodule=data, strategy="freeze")
        root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
    metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
    return ([
        osp.join(root, data, row['Id'] + ".jpg")
        for _, row in metadata.iterrows()
    ], [[int(row[genre]) for genre in genres]
        for _, row in metadata.iterrows()])


# Load the file paths and per-genre targets for the train split, then the test split.
(train_files, train_targets), (test_files, test_targets) = (
    load_data('train'),
    load_data('test'),
)

# Wrap both splits in a datamodule; every image is handled at 128x128.
datamodule = ImageClassificationData.from_files(
    image_size=(128, 128),
    val_split=0.1,  # Hold out 10 % of the train dataset as the validation set.
    train_files=train_files,
    train_targets=train_targets,
    test_files=test_files,
    test_targets=test_targets,
)

# 3. Build the model
model = ImageClassifier(
    num_classes=len(genres),
    backbone="resnet18",
    metrics=F1(num_classes=len(genres)),
    multi_label=True,
)

# 4. Create the trainer. Run 1 epoch, limited to a single training batch.
trainer = flash.Trainer(max_epochs=1,
                        limit_train_batches=1,