def test_init():
    """Run one batch through an ``ObjectDetector`` in eval mode.

    Checks that the model returns one output per sample in the batch and
    that the first output exposes the ``boxes``, ``labels`` and ``scores``
    keys.
    """
    detector = ObjectDetector(num_classes=2)
    detector.eval()

    batch_size = 2
    dataset = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
    loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size)

    # Only the images are fed to the model; the targets are not needed here.
    images, _targets = next(iter(loader))
    predictions = detector(images)

    expected_keys = {"boxes", "labels", "scores"}
    assert len(predictions) == batch_size
    assert expected_keys <= predictions[0].keys()
def test_training(tmpdir, model):
    """Fit an ``ObjectDetector`` for a fast dev run without pretrained weights.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
        model: Value forwarded as the detector's ``model`` argument
            (presumably supplied by a pytest parametrisation or fixture
            that is not visible here -- confirm).
    """
    detector = ObjectDetector(
        num_classes=2,
        model=model,
        pretrained=False,
        pretrained_backbone=False,
    )
    loader = DataLoader(
        DummyDetectionDataset((3, 224, 224), 1, 2, 10),
        collate_fn=collate_fn,
    )
    Trainer(default_root_dir=tmpdir, fast_dev_run=True).fit(detector, loader)
def test_detection(tmpdir):
    """Finetune a detector on COCO-format data, then predict on two images.

    The training folder and annotation file come from
    ``_create_synth_coco_dataset(tmpdir)``. After a fast-dev-run finetune,
    two newly created 1920x1080 RGB PNG files are written into ``tmpdir``
    and passed to ``predict``. Nothing is asserted about the predictions;
    the test passes as long as no step raises.
    """
    image_dir, annotation_file = _create_synth_coco_dataset(tmpdir)
    datamodule = ObjectDetectionData.from_coco(
        train_folder=image_dir,
        train_ann_file=annotation_file,
        batch_size=1,
    )

    detector = ObjectDetector(num_classes=datamodule.num_classes)
    flash.Trainer(fast_dev_run=True).finetune(detector, datamodule)

    # Resolve both file paths first, then write an image to each of them.
    image_paths = [
        os.fspath(tmpdir / file_name)
        for file_name in ("test_one.png", "test_two.png")
    ]
    for image_path in image_paths:
        Image.new('RGB', (1920, 1080)).save(image_path)

    detector.predict(image_paths)
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.core.data import download_data
from flash.vision import ObjectDetectionData, ObjectDetector

# Locations used by the steps below.
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
COCO128_URL = "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip"
DOWNLOAD_DIR = "data/"
TRAIN_IMAGE_DIR = "data/coco128/images/train2017/"
TRAIN_ANNOTATION_FILE = "data/coco128/annotations/instances_train2017.json"
CHECKPOINT_PATH = "object_detection_model.pt"

# 1. Download the COCO128 archive into the local data directory.
download_data(COCO128_URL, DOWNLOAD_DIR)

# 2. Build the datamodule from the COCO-format images and annotation file.
datamodule = ObjectDetectionData.from_coco(
    train_folder=TRAIN_IMAGE_DIR,
    train_ann_file=TRAIN_ANNOTATION_FILE,
    batch_size=2,
)

# 3. Build the model, sized to the number of classes in the datamodule.
model = ObjectDetector(num_classes=datamodule.num_classes)

# 4. Create the trainer, capped at three epochs.
trainer = flash.Trainer(max_epochs=3)

# 5. Finetune the model on the datamodule.
trainer.finetune(model, datamodule)

# 6. Save the resulting checkpoint.
trainer.save_checkpoint(CHECKPOINT_PATH)
def test_training(tmpdir):
    """Fit a ``fasterrcnn_resnet50_fpn`` ``ObjectDetector`` for a fast dev run.

    Args:
        tmpdir: Directory used as the trainer's ``default_root_dir``.
    """
    detector = ObjectDetector(num_classes=2, model="fasterrcnn_resnet50_fpn")

    dataset = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
    loader = DataLoader(dataset, collate_fn=collate_fn)

    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.fit(detector, loader)