Esempio n. 1
0
def model(num_classes: int,
          backbone: Optional[TorchvisionBackboneConfig] = None,
          remove_internal_transforms: bool = True,
          pretrained: bool = True,
          **retinanet_kwargs) -> nn.Module:
    """Build a torchvision RetinaNet whose head predicts ``num_classes`` classes.

    Args:
        num_classes: Number of classes the classification head predicts.
        backbone: Optional backbone config. When ``None``, a
            ``retinanet_resnet50_fpn`` model is built and its head replaced;
            otherwise ``backbone.backbone`` is wrapped in a new ``RetinaNet``.
        remove_internal_transforms: If True, ``remove_internal_model_transforms``
            is applied to the built model before it is returned.
        pretrained: Passed as both ``pretrained`` and ``pretrained_backbone``
            to ``retinanet_resnet50_fpn``; unused when ``backbone`` is given.
        **retinanet_kwargs: Forwarded to the RetinaNet constructor.

    Returns:
        The constructed detector module.
    """
    if backbone is not None:
        net = RetinaNet(backbone=backbone.backbone,
                        num_classes=num_classes,
                        **retinanet_kwargs)
    else:
        net = retinanet_resnet50_fpn(pretrained=pretrained,
                                     pretrained_backbone=pretrained,
                                     **retinanet_kwargs)
        # Swap in a head sized for ``num_classes``, reusing the feature width
        # of the backbone and the anchor count of the head being replaced.
        feature_channels = net.backbone.out_channels
        anchors_per_location = net.head.classification_head.num_anchors
        net.head = RetinaNetHead(
            in_channels=feature_channels,
            num_anchors=anchors_per_location,
            num_classes=num_classes,
        )
        resnet_fpn.patch_param_groups(net.backbone)

    patch_retinanet_param_groups(net)

    if remove_internal_transforms:
        remove_internal_model_transforms(net)

    return net
Esempio n. 2
0
def create_retinanet(
    num_classes: int = 91,
    backbone: nn.Module = None,
    **kwargs,
):
    """Create a RetinaNet detector based on the torchvision implementation.

    Args:
        num_classes: Number of output classes, counting the background class.
            Class id ``0`` is reserved for background, so pass the number of
            labelled classes + 1 and never use ``0`` for a real class.
        backbone: Optional backbone module. When ``None``, a pretrained
            ``retinanet_resnet50_fpn`` (built with 91 classes) is loaded and
            its head is replaced by one predicting ``num_classes`` classes.
        **kwargs: Forwarded to ``retinanet_resnet50_fpn`` or ``RetinaNet``.

    Returns:
        The RetinaNet model.
    """
    if backbone is not None:
        return RetinaNet(backbone, num_classes=num_classes, **kwargs)

    # The pretrained weights are loaded with their original 91-class layout;
    # the head is then swapped for one sized to the requested class count.
    detector = retinanet_resnet50_fpn(
        pretrained=True,
        num_classes=91,
        **kwargs,
    )
    detector.head = RetinaNetHead(
        in_channels=detector.backbone.out_channels,
        num_anchors=detector.head.classification_head.num_anchors,
        num_classes=num_classes,
    )
    return detector
Esempio n. 3
0
    def __init__(self, learning_rate: float = 0.0001, num_classes: int = 91,
                 backbone: str = None, fpn: bool = True,
                 pretrained_backbone: str = None, trainable_backbone_layers: int = 3,
                 **kwargs, ):
        """
        Args:
            learning_rate: the learning rate
            num_classes: number of detection classes (including background)
            backbone: name of the backbone CNN. When ``None``, torchvision's
                ``retinanet_resnet50_fpn`` is built with ``pretrained=True``
                and its head is replaced to predict ``num_classes`` classes.
            fpn: if true, builds a Feature Pyramid Network on top of the backbone
                (only used when ``backbone`` is given)
            pretrained_backbone (str): if "imagenet", returns a model with backbone
                pre-trained on Imagenet (only used when ``backbone`` is given)
            trainable_backbone_layers: number of trainable resnet layers starting
                from final block (only used when ``backbone`` is given)
            kwargs: extra detector options (e.g. ``min_size``) forwarded to the
                RetinaNet constructor; they are not forwarded to ``RetinaNetHead``
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.backbone = backbone
        if backbone is None:
            self.model = retinanet_resnet50_fpn(pretrained=True, **kwargs)

            # Only the head's own arguments are passed here: RetinaNetHead does
            # not accept detector-level options, so forwarding **kwargs would
            # make any such option (e.g. ``min_size``) raise a TypeError.
            self.model.head = RetinaNetHead(in_channels=self.model.backbone.out_channels,
                                            num_anchors=self.model.head.classification_head.num_anchors,
                                            num_classes=num_classes)

        else:
            backbone_model = create_retinanet_backbone(self.backbone, fpn, pretrained_backbone,
                                                       trainable_backbone_layers, **kwargs)
            self.model = RetinaNet(backbone_model, num_classes=num_classes, **kwargs)
Esempio n. 4
0
    def __init__(
        self,
        learning_rate: float = 0.0001,
        num_classes: int = 91,
        backbone: Optional[str] = None,
        fpn: bool = True,
        pretrained: bool = False,
        pretrained_backbone: bool = True,
        trainable_backbone_layers: int = 3,
        **kwargs: Any,
    ):
        """
        Args:
            learning_rate: the learning rate
            num_classes: number of detection classes (including background)
            backbone: name of the pretrained backbone CNN architecture. When
                ``None``, torchvision's ``retinanet_resnet50_fpn`` is used and
                its head is replaced to predict ``num_classes`` classes.
            fpn: If True, creates a Feature Pyramid Network on top of Resnet based CNNs
                (only used when ``backbone`` is given).
            pretrained: if true, returns a model pre-trained on COCO train2017
                (only used when ``backbone`` is ``None``).
            pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
            trainable_backbone_layers: number of trainable resnet layers starting from final block
            kwargs: extra detector options (e.g. ``min_size``) forwarded to the
                RetinaNet constructor; they are not forwarded to ``RetinaNetHead``.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.backbone = backbone
        if backbone is None:
            # Forward the documented backbone options so they take effect on
            # this branch as well, not only when a custom backbone is built.
            self.model = retinanet_resnet50_fpn(
                pretrained=pretrained,
                pretrained_backbone=pretrained_backbone,
                trainable_backbone_layers=trainable_backbone_layers,
                **kwargs,
            )

            # Replace the head so it predicts ``num_classes`` classes. Only the
            # head's own arguments are passed: RetinaNetHead does not accept
            # detector-level options, so forwarding **kwargs would make any
            # such option (e.g. ``min_size``) raise a TypeError.
            self.model.head = RetinaNetHead(
                in_channels=self.model.backbone.out_channels,
                num_anchors=self.model.head.classification_head.num_anchors,
                num_classes=num_classes,
            )

        else:
            backbone_model = create_retinanet_backbone(
                self.backbone, fpn, pretrained_backbone,
                trainable_backbone_layers, **kwargs)
            self.model = torchvision_RetinaNet(backbone_model,
                                               num_classes=num_classes,
                                               **kwargs)
Esempio n. 5
0
def main(args):
    """Run a random-sampling active-learning experiment for object detection.

    Each cycle proceeds as follows:

    * It builds a fresh detector: Faster R-CNN when ``args.model`` contains
      ``'faster'``, RetinaNet when it contains ``'retina'``.
    * It trains the detector on the current labeled subset and evaluates it
      on the test split after the last epoch.
    * It moves ``budget_num`` randomly chosen images from the unlabeled pool
      into the labeled set.

    If ``args.skip`` is set and ``args.init`` is not, cycle 0 loads the
    first-cycle checkpoint from ``args.first_checkpoint_path`` instead of
    training. If ``args.skip`` is not set, the model trained in cycle 0 is
    saved there.

    With ``args.test_only``, the freshly built model of the first non-skipped
    cycle is evaluated without training, and the function then returns.

    Args:
        args: options namespace (presumably from argparse) carrying the
            dataset, model, optimisation and path settings read below.

    Raises:
        ValueError: if ``args.model`` contains neither ``'faster'`` nor
            ``'retina'``.
    """
    # Fail fast on an unsupported model name. Otherwise ``task_model`` and
    # ``checkpoint`` are never bound, and the run crashes later with an
    # UnboundLocalError after the datasets have already been loaded.
    if 'faster' not in args.model and 'retina' not in args.model:
        raise ValueError(
            "Unsupported model {!r}: expected a name containing 'faster' or "
            "'retina'".format(args.model))
    use_faster = 'faster' in args.model
    # File name (formatted with args.dataset) of the cycle-0 checkpoint.
    first_ckpt_template = ('{}_frcnn_1st.pth' if use_faster
                           else '{}_retinanet_1st.pth')

    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test",
                                      get_transform(train=False),
                                      args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val",
                                      get_transform(train=False),
                                      args.data_path)
    print("Creating data loaders")
    num_images = len(dataset)
    # Size of the initial labeled pool and of each per-cycle acquisition.
    if 'voc' in args.dataset:
        init_num = 1000
        budget_num = 1000
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = indices[init_num:]
    train_sampler = SubsetRandomSampler(labeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test,
                                  batch_size=1,
                                  sampler=test_sampler,
                                  num_workers=args.workers,
                                  collate_fn=utils.collate_fn)

    # Aspect-ratio groups depend only on the dataset, so compute them once
    # instead of at the start of every cycle.
    use_grouping = args.aspect_ratio_group_factor >= 0
    if use_grouping:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)

    # Detector input resolution (min_size, max_size) per dataset/model.
    if 'voc' in args.dataset:
        min_size, max_size = 600, 1000
    elif use_faster:
        min_size, max_size = 800, 1333
    else:
        # NOTE(review): non-VOC Faster R-CNN uses 800/1333, but RetinaNet
        # uses 600/1000 here -- confirm this is intended.
        min_size, max_size = 600, 1000

    for cycle in range(args.cycles):
        if use_grouping:
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                      args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, args.batch_size, drop_last=True)

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)

        print("Creating model")
        if use_faster:
            task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                 min_size=min_size,
                                                 max_size=max_size)
        else:
            task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                min_size=min_size,
                                                max_size=max_size)
        task_model.to(device)
        if not args.init and cycle == 0 and args.skip:
            # Reuse the cycle-0 model saved by an earlier run instead of
            # training it again.
            checkpoint = torch.load(os.path.join(
                args.first_checkpoint_path,
                first_ckpt_template.format(args.dataset)),
                                    map_location='cpu')
            task_model.load_state_dict(checkpoint['model'])
            print("Getting stability")
            random.shuffle(unlabeled_set)
            if 'coco' in args.dataset:
                subset = unlabeled_set[:5000]
            else:
                subset = unlabeled_set
            # Update the labeled dataset and the unlabeled dataset, respectively
            labeled_set += subset[:budget_num]
            labeled_set = list(set(labeled_set))
            unlabeled_set = list(set(indices) - set(labeled_set))

            # Create a new sampler for the updated labeled dataset
            train_sampler = SubsetRandomSampler(labeled_set)
            continue
        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            task_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        # Start active learning cycles training
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model, data_loader_test, args.dataset)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, data_loader, device,
                            cycle, epoch, args.print_freq)
            task_lr_scheduler.step()
            # evaluate after pre-set epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model,
                                 data_loader_test,
                                 args.dataset,
                                 path=args.results_path)
        if not args.skip and cycle == 0:
            utils.save_on_master(
                {
                    'model': task_model.state_dict(),
                    'args': args
                },
                os.path.join(args.first_checkpoint_path,
                             first_ckpt_template.format(args.dataset)))
        random.shuffle(unlabeled_set)
        # Update the labeled dataset and the unlabeled dataset, respectively
        labeled_set += unlabeled_set[:budget_num]
        labeled_set = list(set(labeled_set))
        unlabeled_set = unlabeled_set[budget_num:]
        # Create a new sampler for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
Esempio n. 6
0
def main(args):
    """Run an active-learning experiment that picks images with a VAE and a discriminator.

    Each cycle proceeds as follows:

    * It builds a fresh detector: Faster R-CNN when ``args.model`` contains
      ``'faster'``, RetinaNet when it contains ``'retina'``.
    * It builds a ``VAE`` and a ``Discriminator`` and trains all three with
      ``train_one_epoch`` on the labeled and unlabeled pools.
    * It evaluates the detector after the last epoch.
    * It calls ``sample_for_labeling`` to choose ``budget_num`` images from the
      unlabeled pool (a random 10000-image subset of it for COCO) and moves
      them into the labeled set.

    With ``args.test_only``, the freshly built detector of the first cycle is
    evaluated without training, and the function then returns.

    Args:
        args: options namespace (presumably from argparse) carrying the
            dataset, model, optimisation and path settings read below.

    Raises:
        ValueError: if ``args.model`` contains neither ``'faster'`` nor
            ``'retina'``.
    """
    # Fail fast on an unsupported model name. Otherwise ``task_model`` is
    # never bound, and the run crashes later with an UnboundLocalError after
    # the datasets have already been loaded.
    if 'faster' not in args.model and 'retina' not in args.model:
        raise ValueError(
            "Unsupported model {!r}: expected a name containing 'faster' or "
            "'retina'".format(args.model))
    use_faster = 'faster' in args.model

    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")
    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test",
                                      get_transform(train=False),
                                      args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val",
                                      get_transform(train=False),
                                      args.data_path)
    # Size of the initial labeled pool and of each per-cycle acquisition.
    if 'voc' in args.dataset:
        init_num = 500
        budget_num = 500
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    print("Creating data loaders")
    num_images = len(dataset)
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = list(set(indices) - set(labeled_set))
    train_sampler = SubsetRandomSampler(labeled_set)
    unlabeled_sampler = SubsetRandomSampler(unlabeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test,
                                  batch_size=1,
                                  sampler=test_sampler,
                                  num_workers=args.workers,
                                  collate_fn=utils.collate_fn)

    # Aspect-ratio groups depend only on the dataset, so compute them once
    # instead of at the start of every cycle.
    use_grouping = args.aspect_ratio_group_factor >= 0
    if use_grouping:
        group_ids = create_aspect_ratio_groups(
            dataset, k=args.aspect_ratio_group_factor)

    # Detector input resolution depends only on the dataset.
    if 'voc' in args.dataset:
        min_size, max_size = 600, 1000
    else:
        min_size, max_size = 800, 1333

    for cycle in range(args.cycles):
        if use_grouping:
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                      args.batch_size)
            unlabeled_batch_sampler = GroupedBatchSampler(
                unlabeled_sampler, group_ids, args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, args.batch_size, drop_last=True)
            unlabeled_batch_sampler = torch.utils.data.BatchSampler(
                unlabeled_sampler, args.batch_size, drop_last=True)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)
        unlabeled_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=unlabeled_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)
        print("Creating model")
        if use_faster:
            task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                 min_size=min_size,
                                                 max_size=max_size)
        else:
            task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                min_size=min_size,
                                                max_size=max_size)
        task_model.to(device)

        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            task_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

        # Move each auxiliary module to the device before building its
        # optimizer (idiomatic order; the parameters are the same objects).
        vae = VAE()
        vae.to(device)
        params = [p for p in vae.parameters() if p.requires_grad]
        vae_optimizer = torch.optim.SGD(params,
                                        lr=args.lr / 10,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        vae_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            vae_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        # NOTE: gradient clipping (e.g. torch.nn.utils.clip_grad_value_) only
        # has an effect after backward() has produced gradients, i.e. inside
        # the training step. Calling it here, before any gradients exist,
        # does nothing.
        # TODO(review): if VAE clipping is wanted, apply it in train_one_epoch.

        discriminator = Discriminator()
        discriminator.to(device)
        params = [p for p in discriminator.parameters() if p.requires_grad]
        discriminator_optimizer = torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        discriminator_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            discriminator_optimizer,
            milestones=args.lr_steps,
            gamma=args.lr_gamma)
        # Start active learning cycles training
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model,
                             data_loader_test,
                             args.dataset,
                             False,
                             path=args.results_path)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, vae, vae_optimizer,
                            discriminator, discriminator_optimizer,
                            data_loader, unlabeled_dataloader, device, cycle,
                            epoch, args.print_freq)
            task_lr_scheduler.step()
            vae_lr_scheduler.step()
            discriminator_lr_scheduler.step()
            # evaluate after pre-set epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model,
                                 data_loader_test,
                                 args.dataset,
                                 False,
                                 path=args.results_path)
        # Update the labeled dataset and the unlabeled dataset, respectively
        random.shuffle(unlabeled_set)
        if 'coco' in args.dataset:
            subset = unlabeled_set[:10000]
        else:
            subset = unlabeled_set
        # Sequential sampling keeps loader order aligned with ``subset``, so
        # the indices returned below can be mapped back through ``subset``.
        unlabeled_loader = DataLoader(dataset,
                                      batch_size=1,
                                      sampler=SubsetSequentialSampler(subset),
                                      num_workers=args.workers,
                                      pin_memory=True,
                                      collate_fn=utils.collate_fn)
        tobe_labeled_inds = sample_for_labeling(vae, discriminator,
                                                unlabeled_loader, budget_num)
        tobe_labeled_set = [subset[i] for i in tobe_labeled_inds]
        labeled_set += tobe_labeled_set
        unlabeled_set = list(set(unlabeled_set) - set(tobe_labeled_set))
        # Create new samplers for the updated labeled / unlabeled pools
        train_sampler = SubsetRandomSampler(labeled_set)
        unlabeled_sampler = SubsetRandomSampler(unlabeled_set)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))