x = self.effnet.extract_features(image) x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1) outputs = self.out(self.dropout(x)) if targets is not None: loss = nn.CrossEntropyLoss()(outputs, targets) metrics = self.monitor_metrics(outputs, targets) return outputs, loss, metrics return outputs, None, None train_dataset = ImageDataset( image_paths=train_image_paths, targets=train_targets, resize=(255,255), augmentations=train_aug, ) valid_dataset = ImageDataset( image_paths=valid_image_paths, targets=valid_targets, resize=(255,255), augmentations=valid_aug, ) model = SnakeModel(num_classes=dfx.breed.nunique()) es = EarlyStopping( monitor="valid_loss", model_path="model.bin", patience=3, mode="min" ) model.fit(
valid_image_paths = glob.glob( os.path.join(INPUT_PATH, f"jpeg-{IMAGE_SIZE}x{IMAGE_SIZE}", "val", "**", "*.jpeg"), recursive=True, ) train_targets = [x.split("/")[-2] for x in train_image_paths] valid_targets = [x.split("/")[-2] for x in valid_image_paths] lbl_enc = preprocessing.LabelEncoder() train_targets = lbl_enc.fit_transform(train_targets) valid_targets = lbl_enc.transform(valid_targets) train_dataset = ImageDataset( image_paths=train_image_paths, targets=train_targets, resize=None, augmentations=train_aug, ) valid_dataset = ImageDataset( image_paths=valid_image_paths, targets=valid_targets, resize=None, augmentations=valid_aug, ) model = FlowerModel(num_classes=len(lbl_enc.classes_)) es = EarlyStopping( monitor="valid_loss", model_path=os.path.join(MODEL_PATH, MODEL_NAME + ".bin"),
recursive=True, ) test_image_paths = glob.glob( os.path.join(INPUT_PATH, f"jpeg-{IMAGE_SIZE}x{IMAGE_SIZE}", "test", "*.jpeg"), ) train_targets = [x.split("/")[-2] for x in train_image_paths] valid_targets = [x.split("/")[-2] for x in valid_image_paths] train_targets = [CLASSES[c] for c in train_targets] valid_targets = [CLASSES[c] for c in valid_targets] train_dataset = ImageDataset( image_paths=train_image_paths, targets=train_targets, augmentations=train_aug, ) valid_dataset = ImageDataset( image_paths=valid_image_paths, targets=valid_targets, augmentations=valid_aug, ) test_dataset = ImageDataset( image_paths=test_image_paths, targets=[0] * len(test_image_paths), augmentations=test_aug, )