    def setUp(self):
        self.path_to_frames = Path(__file__).parent / 'fixtures/'
        self.wildfire_path = (Path(__file__).parent
                              / 'fixtures/wildfire_dataset.csv')
        self.wildfire_df = pd.read_csv(self.wildfire_path)

        self.wildfire = WildFireDataset(metadata=self.wildfire_path,
                                        path_to_frames=self.path_to_frames)

    def test_wildfire_correctly_init_with_transform(self):
        wildfire = WildFireDataset(metadata=self.wildfire_path,
                                   path_to_frames=self.path_to_frames,
                                   transform=transforms.Compose([
                                       transforms.Resize((100, 66)),
                                       transforms.ToTensor()
                                   ]))

        observation_3, metadata_3 = wildfire[3]
        self.assertEqual(observation_3.size(), torch.Size((3, 100, 66)))

    def test_wildfire_correctly_init_from_dataframe(self):
        wildfire = WildFireDataset(metadata=self.wildfire_df,
                                   path_to_frames=self.path_to_frames)

        self.assertEqual(len(wildfire), 974)

        # try to get one image of wildfire (item 3 is authorized image fixture)
        observation_3, metadata_3 = wildfire[3]
        self.assertIsInstance(observation_3,
                              PIL.Image.Image)  # image correctly loaded?
        self.assertEqual(observation_3.size, (910, 683))
        self.assertTrue(torch.equal(metadata_3, torch.tensor(
            [0])))  # metadata correctly loaded?

    def test_wildfire_correctly_init_with_multiple_targets(self):
        wildfire = WildFireDataset(metadata=self.wildfire_df,
                                   path_to_frames=self.path_to_frames,
                                   transform=transforms.ToTensor(),
                                   target_names=['fire', 'fire_id'])

        self.assertEqual(len(wildfire), 974)

        # try to get one image of wildfire (item 3 is authorized image fixture)
        observation_3, metadata_3 = wildfire[3]
        self.assertIsInstance(observation_3,
                              torch.Tensor)  # image correctly loaded?
        self.assertEqual(observation_3.size(), torch.Size([3, 683, 910]))
        self.assertTrue(torch.equal(metadata_3, torch.tensor(
            [0, 96])))  # metadata correctly loaded?

    def test_dataloader_can_be_init_with_wildfire(self):
        wildfire = WildFireDataset(metadata=self.wildfire_path,
                                   path_to_frames=self.path_to_frames)
        DataLoader(wildfire, batch_size=64)

    def test_invalid_csv_path_raises_exception(self):
        with self.assertRaises(ValueError):
            WildFireDataset(metadata='bad_path.csv',
                            path_to_frames=self.path_to_frames)

    def test_wildfire_correctly_init_from_path(self):
        wildfire = WildFireDataset(metadata=self.wildfire_path,
                                   path_to_frames=self.path_to_frames)

        self.assertEqual(len(wildfire), 974)
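
These methods come from a unittest.TestCase subclass that the excerpt omits. Below is a minimal harness sketch; the class name, the fixture imports and the pyrovision module path are assumptions, not taken from the excerpt:

import unittest
from pathlib import Path

import PIL
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from pyrovision.datasets.wildfire import WildFireDataset  # assumed module path


class WildFireDatasetTester(unittest.TestCase):  # hypothetical class name
    # setUp and the test_* methods above go here
    ...


if __name__ == '__main__':
    unittest.main()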
Example #8
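
The script below relies on imports it does not show. A plausible header, with the pyrovision paths and the training helpers marked as assumptions:

import math
from pathlib import Path

import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from fastprogress import master_bar

from pyrovision import models                               # assumed path
from pyrovision.datasets.wildfire import (WildFireDataset,  # assumed path
                                          WildFireSplitter)
# set_seed, train_epoch and evaluate are project training utilities
# (a sketch of the latter two follows the script)
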
def main(args):

    if args.deterministic:
        set_seed(42)

    # Set device
    if args.device is None:
        if torch.cuda.is_available():
            args.device = 'cuda:0'
        else:
            args.device = 'cpu'

    # Create dataloaders

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=args.resize, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(size=args.resize),
        transforms.CenterCrop(size=args.resize),
        transforms.ToTensor(), normalize
    ])

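    # Map each split to its transform pipeline: training gets augmentation,
    # while val/test share the deterministic resize + crop pipeline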
    tf = {
        'train': train_transforms,
        'test': val_transforms,
        'val': val_transforms
    }

    metadata = pd.read_csv(args.metadata)

    wildfire = WildFireDataset(metadata=metadata,
                               path_to_frames=Path(args.DB),
                               target_names=['fire'])

    ratios = {
        'train': args.ratio_train,
        'val': args.ratio_val,
        'test': args.ratio_test
    }

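    # WildFireSplitter partitions the dataset according to the given ratios
    # (presumably grouping frames of the same fire into a single split to
    # avoid leakage) and applies the per-split transforms defined in `tf`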
    splitter = WildFireSplitter(ratios, transforms=tf)
    splitter.fit(wildfire)

    train_loader = DataLoader(splitter.train,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(splitter.val,
                            batch_size=args.batch_size,
                            shuffle=True)

    # Model definition
    model = models.__dict__[args.model](imagenet_pretrained=args.pretrained,
                                        num_classes=args.nb_class,
                                        lin_features=args.lin_feats,
                                        concat_pool=args.concat_pool,
                                        bn_final=args.bn_final,
                                        dropout_prob=args.dropout_prob)

    # Freeze layers
    if not args.unfreeze:
        # Model is sequential
        for p in model[1].parameters():
            p.requires_grad = False

    # Resume
    if args.resume:
        model.load_state_dict(torch.load(args.resume)['model'])

    # Send to device
    model.to(args.device)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Optimizer (the learning rate is left at its default here; the
    # OneCycleLR scheduler below drives it at every step)
    optimizer = optim.Adam(model.parameters(),
                           betas=(0.9, 0.99),
                           weight_decay=args.weight_decay)

    # Scheduler: OneCycleLR steps once per batch, hence
    # steps_per_epoch=len(train_loader)
    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=len(train_loader),
        cycle_momentum=(not isinstance(optimizer, optim.Adam)),
        div_factor=args.div_factor,
        final_div_factor=args.final_div_factor)

    best_loss = math.inf
    mb = master_bar(range(args.epochs))
    for epoch_idx in mb:
        # Training
        train_loss = train_epoch(model,
                                 train_loader,
                                 optimizer,
                                 criterion,
                                 master_bar=mb,
                                 epoch=epoch_idx,
                                 scheduler=lr_scheduler,
                                 device=args.device)

        # Evaluation
        val_loss, acc = evaluate(model,
                                 val_loader,
                                 criterion,
                                 device=args.device)

        mb.comment = f"Epoch {epoch_idx+1}/{args.epochs}"
        mb.write(
            f"Epoch {epoch_idx+1}/{args.epochs} - Training loss: {train_loss:.4} | "
            f"Validation loss: {val_loss:.4} | Error rate: {1 - acc:.4}")

        # State saving
        if val_loss < best_loss:
            if args.output_dir:
                print(
                    f"Validation loss decreased {best_loss:.4} --> {val_loss:.4}: saving state..."
                )
                torch.save(
                    dict(model=model.state_dict(),
                         optimizer=optimizer.state_dict(),
                         lr_scheduler=lr_scheduler.state_dict(),
                         epoch=epoch_idx,
                         args=args),
                    Path(args.output_dir, f"{args.checkpoint}.pth"))
            best_loss = val_loss
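
train_epoch and evaluate are defined elsewhere in the project. The sketch below only matches the call sites above: the signatures are inferred from those calls, while the fastprogress progress_bar usage and the (batch, 1)-shaped integer targets are assumptions:

import torch
from fastprogress import progress_bar


def train_epoch(model, loader, optimizer, criterion,
                master_bar=None, epoch=0, scheduler=None, device='cpu'):
    # One pass over the training data; OneCycleLR is stepped once per batch.
    # `epoch` is accepted only to match the call site above.
    model.train()
    running_loss = 0.
    for x, target in progress_bar(loader, parent=master_bar):
        x, target = x.to(device), target.squeeze(1).to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), target)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def evaluate(model, loader, criterion, device='cpu'):
    # Average loss and top-1 accuracy over the validation set
    model.eval()
    val_loss, correct, total = 0., 0, 0
    with torch.no_grad():
        for x, target in loader:
            x, target = x.to(device), target.squeeze(1).to(device)
            out = model(x)
            val_loss += criterion(out, target).item()
            correct += (out.argmax(dim=1) == target).sum().item()
            total += target.numel()
    return val_loss / len(loader), correct / total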