def setUp(self):
    """Build the fixtures shared by every test.

    Defines both a CSV path and a pre-parsed DataFrame for the metadata,
    plus a str variant of the frames directory, because the tests exercise
    every accepted ``metadata`` / ``path_to_frames`` input type.
    """
    self.path_to_frames = Path(__file__).parent / 'fixtures/'
    # Some tests pass the frames directory as a plain string rather than a Path.
    self.path_to_frames_str = str(self.path_to_frames)
    self.wildfire_path = Path(__file__).parent / 'fixtures/wildfire_dataset.csv'
    # The dataframe-based tests read this attribute, so it must actually be
    # materialized here (it was left commented out, breaking those tests).
    self.wildfire_df = pd.read_csv(self.wildfire_path)
    self.wildfire = WildFireDataset(metadata=self.wildfire_path,
                                    path_to_frames=self.path_to_frames)
def test_wildfire_correctly_init_from_path(self):
    """Dataset built from a metadata CSV path works with Path and str frame dirs."""
    for frames_dir in (self.path_to_frames, self.path_to_frames_str):
        dataset = WildFireDataset(metadata=self.wildfire_path,
                                  path_to_frames=frames_dir)
        self.assertEqual(len(dataset), 974)
        self.assertEqual(len(dataset[3]), 2)
def test_wildfire_correctly_init_with_transform(self):
    """A torchvision transform pipeline is applied when fetching an item."""
    pipeline = transforms.Compose([
        transforms.Resize((100, 66)),
        transforms.ToTensor(),
    ])
    dataset = WildFireDataset(metadata=self.wildfire_path,
                              path_to_frames=self.path_to_frames,
                              transform=pipeline)
    observation_3, metadata_3 = dataset[3]
    self.assertEqual(observation_3.size(), torch.Size((3, 100, 66)))
def test_wildfire_correctly_init_with_multiple_targets(self):
    """Requesting several target columns yields a multi-valued metadata tensor."""
    dataset = WildFireDataset(metadata=self.wildfire_df,
                              path_to_frames=self.path_to_frames,
                              transform=transforms.ToTensor(),
                              target_names=['fire', 'fire_id'])
    self.assertEqual(len(dataset), 974)
    # Item 3 is the authorized image fixture.
    observation_3, metadata_3 = dataset[3]
    self.assertIsInstance(observation_3, torch.Tensor)  # image correctly loaded ?
    self.assertEqual(observation_3.size(), torch.Size([3, 683, 910]))
    # metadata correctly loaded ?
    self.assertTrue(torch.equal(metadata_3, torch.tensor([0, 96])))
def test_wildfire_correctly_init_from_dataframe(self):
    """Dataset built from a pre-parsed DataFrame works with Path and str frame dirs."""
    for frames_dir in (self.path_to_frames, self.path_to_frames_str):
        dataset = WildFireDataset(metadata=self.wildfire_df,
                                  path_to_frames=frames_dir)
        self.assertEqual(len(dataset), 974)
        self.assertEqual(len(dataset[3]), 2)
        # Item 3 is the authorized image fixture.
        observation_3, metadata_3 = dataset[3]
        self.assertIsInstance(observation_3, PIL.Image.Image)  # image correctly loaded ?
        self.assertEqual(observation_3.size, (910, 683))
        # metadata correctly loaded ?
        self.assertTrue(torch.equal(metadata_3, torch.tensor([0])))
def test_dataloader_can_be_init_with_wildfire(self):
    """A torch DataLoader accepts the dataset without raising."""
    dataset = WildFireDataset(metadata=self.wildfire_path,
                              path_to_frames=self.path_to_frames)
    DataLoader(dataset, batch_size=64)
def test_invalid_csv_path_raises_exception(self):
    """A non-existent metadata CSV path must raise ValueError."""
    with self.assertRaises(ValueError):
        WildFireDataset(metadata='bad_path.csv',
                        path_to_frames=self.path_to_frames)
def main(args):
    """Train an image classifier on the wildfire dataset.

    Builds augmented train / plain val dataloaders, instantiates a model from
    the project's ``models`` registry, optionally freezes part of it and/or
    resumes from a checkpoint, then runs a one-cycle training loop and saves
    a checkpoint whenever the validation loss improves.

    Args:
        args: parsed CLI namespace. Fields read here: deterministic, device,
            resize, metadata, DB, ratio_train/ratio_val/ratio_test,
            batch_size, model, pretrained, nb_class, lin_feats, concat_pool,
            bn_final, dropout_prob, unfreeze, resume, weight_decay, lr,
            epochs, div_factor, final_div_factor, output_dir, checkpoint.
    """
    if args.deterministic:
        set_seed(42)

    # Set device: prefer CUDA when available and none was requested.
    if args.device is None:
        if torch.cuda.is_available():
            args.device = 'cuda:0'
        else:
            args.device = 'cpu'

    # Create dataloaders. Normalization stats are the standard ImageNet
    # mean/std, matching the pretrained backbones.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Train-time augmentation: crop, rotate, jitter, flip.
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=args.resize, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    # Deterministic resize/center-crop for evaluation splits.
    val_transforms = transforms.Compose([
        transforms.Resize(size=args.resize),
        transforms.CenterCrop(size=args.resize),
        transforms.ToTensor(),
        normalize
    ])
    # Per-split transforms handed to the splitter.
    tf = {'train': train_transforms, 'test': val_transforms, 'val': val_transforms}

    metadata = pd.read_csv(args.metadata)
    wildfire = WildFireDataset(metadata=metadata,
                               path_to_frames=Path(args.DB),
                               target_names=['fire'])
    ratios = {'train': args.ratio_train, 'val': args.ratio_val, 'test': args.ratio_test}
    splitter = WildFireSplitter(ratios, transforms=tf)
    splitter.fit(wildfire)

    train_loader = DataLoader(splitter.train, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(splitter.val, batch_size=args.batch_size, shuffle=True)

    # Model definition: looked up by name in the project's models module.
    model = models.__dict__[args.model](imagenet_pretrained=args.pretrained,
                                        num_classes=args.nb_class,
                                        lin_features=args.lin_feats,
                                        concat_pool=args.concat_pool,
                                        bn_final=args.bn_final,
                                        dropout_prob=args.dropout_prob)

    # Freeze layers unless full fine-tuning was requested.
    if not args.unfreeze:
        # Model is sequential
        for p in model[1].parameters():
            p.requires_grad = False

    # Resume training from a previously saved checkpoint.
    if args.resume:
        model.load_state_dict(torch.load(args.resume)['model'])

    # Send to device
    model.to(args.device)

    # Loss function
    criterion = nn.CrossEntropyLoss()
    # Optimizer; the LR itself is driven by the one-cycle scheduler below.
    optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.99),
                           weight_decay=args.weight_decay)
    # Scheduler: momentum cycling is disabled for Adam (it has no `momentum`).
    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader),
        cycle_momentum=(not isinstance(optimizer, optim.Adam)),
        div_factor=args.div_factor, final_div_factor=args.final_div_factor)

    best_loss = math.inf
    mb = master_bar(range(args.epochs))
    for epoch_idx in mb:
        # Training
        train_loss = train_epoch(model, train_loader, optimizer, criterion,
                                 master_bar=mb, epoch=epoch_idx,
                                 scheduler=lr_scheduler, device=args.device)

        # Evaluation
        val_loss, acc = evaluate(model, val_loader, criterion, device=args.device)

        mb.comment = f"Epoch {epoch_idx+1}/{args.epochs}"
        mb.write(f"Epoch {epoch_idx+1}/{args.epochs} - Training loss: {train_loss:.4} | "
                 f"Validation loss: {val_loss:.4} | Error rate: {1 - acc:.4}")

        # State saving: checkpoint only on validation-loss improvement.
        if val_loss < best_loss:
            if args.output_dir:
                print(f"Validation loss decreased {best_loss:.4} --> {val_loss:.4}: saving state...")
                torch.save(dict(model=model.state_dict(),
                                optimizer=optimizer.state_dict(),
                                lr_scheduler=lr_scheduler.state_dict(),
                                epoch=epoch_idx,
                                args=args),
                           Path(args.output_dir, f"{args.checkpoint}.pth"))
            best_loss = val_loss