import unittest
from pathlib import Path

import pandas as pd
import PIL.Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Adjust this import to your project layout.
from pyrovision.datasets import WildFireDataset


class WildFireDatasetTester(unittest.TestCase):
    def setUp(self):
        self.path_to_frames = Path(__file__).parent / 'fixtures/'
        self.path_to_frames_str = str(self.path_to_frames)
        self.wildfire_path = Path(__file__).parent / 'fixtures/wildfire_dataset.csv'
        self.wildfire_df = pd.read_csv(self.wildfire_path)

        self.wildfire = WildFireDataset(metadata=self.wildfire_path,
                                        path_to_frames=self.path_to_frames)

    def test_wildfire_correctly_init_from_path(self):
        for path_to_frames in [self.path_to_frames, self.path_to_frames_str]:
            wildfire = WildFireDataset(metadata=self.wildfire_path,
                                       path_to_frames=path_to_frames)

            self.assertEqual(len(wildfire), 974)
            self.assertEqual(len(wildfire[3]), 2)

    def test_wildfire_correctly_init_with_transform(self):
        wildfire = WildFireDataset(metadata=self.wildfire_path,
                                   path_to_frames=self.path_to_frames,
                                   transform=transforms.Compose([transforms.Resize((100, 66)),
                                                                 transforms.ToTensor()]))

        observation_3, metadata_3 = wildfire[3]
        self.assertEqual(observation_3.size(), torch.Size((3, 100, 66)))

    def test_wildfire_correctly_init_with_multiple_targets(self):
        wildfire = WildFireDataset(metadata=self.wildfire_df,
                                   path_to_frames=self.path_to_frames,
                                   transform=transforms.ToTensor(),
                                   target_names=['fire', 'fire_id'])

        self.assertEqual(len(wildfire), 974)

        # try to get one image of wildfire (item 3 is authorized image fixture)
        observation_3, metadata_3 = wildfire[3]
        self.assertIsInstance(observation_3, torch.Tensor)  # image correctly loaded ?
        self.assertEqual(observation_3.size(), torch.Size([3, 683, 910]))
        self.assertTrue(torch.equal(metadata_3, torch.tensor([0, 96])))  # metadata correctly loaded ?

    def test_wildfire_correctly_init_from_dataframe(self):
        for path_to_frames in [self.path_to_frames, self.path_to_frames_str]:
            wildfire = WildFireDataset(metadata=self.wildfire_df,
                                       path_to_frames=path_to_frames)

            self.assertEqual(len(wildfire), 974)
            self.assertEqual(len(wildfire[3]), 2)

        # try to get one image of wildfire (item 3 is authorized image fixture)
        observation_3, metadata_3 = wildfire[3]
        self.assertIsInstance(observation_3, PIL.Image.Image)  # image correctly loaded ?
        self.assertEqual(observation_3.size, (910, 683))
        self.assertTrue(torch.equal(metadata_3, torch.tensor([0])))  # metadata correctly loaded ?

    def test_dataloader_can_be_init_with_wildfire(self):
        wildfire = WildFireDataset(metadata=self.wildfire_path,
                                   path_to_frames=self.path_to_frames)
        DataLoader(wildfire, batch_size=64)

    def test_invalid_csv_path_raises_exception(self):
        with self.assertRaises(ValueError):
            WildFireDataset(metadata='bad_path.csv',
                            path_to_frames=self.path_to_frames)
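
For context, here is a minimal usage sketch distilled from the tests above (reusing the imports at the top of the file); the file names, transform, and target column are illustrative placeholders rather than the original fixtures:

# Hypothetical usage of WildFireDataset, mirroring the tests above.
dataset = WildFireDataset(metadata='wildfire_dataset.csv',    # CSV path or a pandas DataFrame
                          path_to_frames=Path('frames/'),     # directory holding the frames
                          transform=transforms.ToTensor(),    # PIL image -> (C, H, W) float tensor
                          target_names=['fire'])              # metadata column(s) returned as target
observation, target = dataset[3]                              # one (image, target) pair
loader = DataLoader(dataset, batch_size=64, shuffle=True)     # ready for a training loop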
Example #8
import math
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from fastprogress import master_bar
from torch.utils.data import DataLoader
from torchvision import transforms

# Adjust these imports to your project layout; set_seed, train_epoch and
# evaluate are assumed to be defined elsewhere in this training script.
from pyrovision import models
from pyrovision.datasets import WildFireDataset
from pyrovision.datasets.wildfire import WildFireSplitter


def main(args):

    if args.deterministic:
        set_seed(42)

    # Set device
    if args.device is None:
        if torch.cuda.is_available():
            args.device = 'cuda:0'
        else:
            args.device = 'cpu'

    # Create dataloaders

    # Channel statistics from ImageNet, matching the pretrained backbones
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=args.resize, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(size=args.resize),
        transforms.CenterCrop(size=args.resize),
        transforms.ToTensor(),
        normalize
    ])

    tf = {'train': train_transforms, 'test': val_transforms, 'val': val_transforms}

    metadata = pd.read_csv(args.metadata)

    wildfire = WildFireDataset(metadata=metadata, path_to_frames=Path(args.DB), target_names=['fire'])

    ratios = {'train': args.ratio_train, 'val': args.ratio_val, 'test': args.ratio_test}

    splitter = WildFireSplitter(ratios, transforms=tf)
    splitter.fit(wildfire)
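    # After fitting, the splitter exposes train/val/test subsets, each paired
    # with the matching transform from the `tf` dict above.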

    train_loader = DataLoader(splitter.train, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(splitter.val, batch_size=args.batch_size, shuffle=True)

    # Model definition
    model = models.__dict__[args.model](imagenet_pretrained=args.pretrained,
                                        num_classes=args.nb_class, lin_features=args.lin_feats,
                                        concat_pool=args.concat_pool, bn_final=args.bn_final,
                                        dropout_prob=args.dropout_prob)

    # Freeze layers
    if not args.unfreeze:
        # Model is sequential
        for p in model[1].parameters():
            p.requires_grad = False

    # Resume
    if args.resume:
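        # Checkpoints are the dicts saved in the state-saving block below,
        # hence the ['model'] key.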
        model.load_state_dict(torch.load(args.resume)['model'])

    # Send to device
    model.to(args.device)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Optimizer (the learning rate is driven per step by the OneCycleLR scheduler below)
    optimizer = optim.Adam(model.parameters(),
                           betas=(0.9, 0.99),
                           weight_decay=args.weight_decay)

    # Scheduler
    lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr,
                                                 epochs=args.epochs, steps_per_epoch=len(train_loader),
                                                 cycle_momentum=(not isinstance(optimizer, optim.Adam)),
                                                 div_factor=args.div_factor, final_div_factor=args.final_div_factor)

    best_loss = math.inf
    mb = master_bar(range(args.epochs))
    for epoch_idx in mb:
        # Training
        train_loss = train_epoch(model, train_loader, optimizer, criterion,
                                 master_bar=mb, epoch=epoch_idx, scheduler=lr_scheduler,
                                 device=args.device)

        # Evaluation
        val_loss, acc = evaluate(model, val_loader, criterion, device=args.device)

        mb.comment = f"Epoch {epoch_idx+1}/{args.epochs}"
        mb.write(f"Epoch {epoch_idx+1}/{args.epochs} - Training loss: {train_loss:.4} | "
                 f"Validation loss: {val_loss:.4} | Error rate: {1 - acc:.4}")

        # State saving
        if val_loss < best_loss:
            if args.output_dir:
                print(f"Validation loss decreased {best_loss:.4} --> {val_loss:.4}: saving state...")
                torch.save(dict(model=model.state_dict(),
                                optimizer=optimizer.state_dict(),
                                lr_scheduler=lr_scheduler.state_dict(),
                                epoch=epoch_idx,
                                args=args),
                           Path(args.output_dir, f"{args.checkpoint}.pth"))
            best_loss = val_loss
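
For completeness, a sketch of a command-line interface matching the attributes main() reads; the flag names mirror the code above, but every type, default, and help string here is an assumption rather than the original parser:

import argparse

# Hypothetical CLI for main(); flag names mirror the attributes read above,
# while types and defaults are illustrative guesses.
parser = argparse.ArgumentParser(description='Wildfire classifier training')
parser.add_argument('DB', type=str, help='path to the frames directory')
parser.add_argument('metadata', type=str, help='path to the metadata CSV')
parser.add_argument('--model', type=str, default='resnet18', help='model architecture')
parser.add_argument('--pretrained', action='store_true', help='load ImageNet-pretrained weights')
parser.add_argument('--nb-class', type=int, default=2, help='number of output classes')
parser.add_argument('--lin-feats', type=int, default=512, help='hidden features of the head')
parser.add_argument('--concat-pool', action='store_true')
parser.add_argument('--bn-final', action='store_true')
parser.add_argument('--dropout-prob', type=float, default=0.5)
parser.add_argument('--unfreeze', action='store_true', help='train the full network')
parser.add_argument('--resume', type=str, default=None, help='checkpoint to resume from')
parser.add_argument('--device', type=str, default=None, help="e.g. 'cuda:0' or 'cpu'")
parser.add_argument('--deterministic', action='store_true', help='seed RNGs for reproducibility')
parser.add_argument('--resize', type=int, default=224, help='input image size')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--lr', type=float, default=3e-3, help='peak learning rate for OneCycleLR')
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--div-factor', type=float, default=25.)
parser.add_argument('--final-div-factor', type=float, default=1e4)
parser.add_argument('--ratio-train', type=float, default=0.8)
parser.add_argument('--ratio-val', type=float, default=0.1)
parser.add_argument('--ratio-test', type=float, default=0.1)
parser.add_argument('--output-dir', type=str, default=None, help='where checkpoints are written')
parser.add_argument('--checkpoint', type=str, default='model', help='checkpoint file name stem')

main(parser.parse_args())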