Example #1
 def test_VOC(self):
     print('start')
     trainset = datasets.VOCDetection(root=cache_path,
                                      year='2012',
                                      image_set='train',
                                      download=False)
     trainvalset = datasets.VOCDetection(root=cache_path,
                                         year='2012',
                                         image_set='trainval',
                                         download=False)
     valset = datasets.VOCDetection(root=cache_path,
                                    year='2012',
                                    image_set='val',
                                    download=False)
     for dataset in [trainset, trainvalset, valset]:
         for index in range(len(dataset)):
             img, target = dataset[index]
             self.assertIsInstance(target, dict)
             self.assertIn('annotation', target)
             self.assertEqual(len(target), 1)
             self.assertIn('object', target['annotation'])
             for box in target['annotation']['object']:
                 for attr in [
                         'name', 'pose', 'truncated', 'occluded', 'bndbox'
                 ]:
                     self.assertIn(attr, box)
                 bndbox = box['bndbox']
                 self.assertEqual(4, len(bndbox))
                 for attr in ['xmin', 'ymin', 'xmax', 'ymax']:
                     self.assertIn(attr, bndbox)
                     self.assertIsInstance(bndbox[attr], str)
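For reference, torchvision parses each VOC XML annotation into nested dictionaries whose leaf values are all strings, which is exactly what these assertions check. A hypothetical single target (values invented for illustration):

target = {
    'annotation': {
        'folder': 'VOC2012',
        'filename': '2008_000008.jpg',
        'size': {'width': '500', 'height': '442', 'depth': '3'},
        'object': [{
            'name': 'horse',
            'pose': 'Left',
            'truncated': '0',
            'occluded': '1',
            'difficult': '0',
            'bndbox': {'xmin': '53', 'ymin': '87',
                       'xmax': '471', 'ymax': '420'},
        }],
    },
}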
Example #2
def main():
    # Fix random seed
    torch.manual_seed(0)
    # Initialize network
    net = models.resnet50(num_classes=NUM_CLASSES)
    # Initialize loss function
    criterion = torch.nn.BCEWithLogitsLoss()
    # Prepare dataset
    train_loader = torch.utils.data.DataLoader(datasets.VOCDetection(
        './data',
        image_set='train',
        download=True,
        transform=transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        target_transform=target_transform),
                                               batch_size=16,
                                               shuffle=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(datasets.VOCDetection(
        './data',
        image_set='val',
        transform=transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        target_transform=target_transform),
                                             batch_size=32,
                                             num_workers=4)
    # Initialize learning engine and start training
    engine = MultiLabelClassificationEngine(net,
                                            criterion,
                                            train_loader,
                                            val_loader=val_loader,
                                            print_interval=50,
                                            optim_params={
                                                'lr': 0.1,
                                                'momentum': 0.9,
                                                'weight_decay': 5e-4
                                            })
    # Train the network for one epoch with default optimizer option
    # Checkpoints will be saved under ./checkpoints by default, containing
    # saved model parameters, optimizer statistics and progress
    engine(1)
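`target_transform`, `NUM_CLASSES`, and `MultiLabelClassificationEngine` come from the surrounding module and are not shown. Given the `BCEWithLogitsLoss` criterion, `target_transform` presumably maps the VOC annotation dict to a multi-hot vector over the 20 VOC categories (so `NUM_CLASSES` would be 20); a minimal sketch under that assumption:

import torch

# The 20 Pascal VOC categories.
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]

def target_transform(target):
    # Multi-hot vector: 1.0 for every category present in the image.
    labels = torch.zeros(len(VOC_CLASSES))
    for obj in target['annotation']['object']:
        labels[VOC_CLASSES.index(obj['name'])] = 1.0
    return labels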
Example #3
def get_data_loader(dataset_name, batch_size=64, shuffle=False):
    """
    Returns a DataLoader with validation images for dataset_name
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    if dataset_name == 'imagenet':
        val_dataset = datasets.ImageNet(IMAGENET_VAL_DIR,
                                        split='val',
                                        transform=transform)
    elif dataset_name == 'cifar10':
        val_dataset = datasets.CIFAR10(CIFAR_VAL_DIR,
                                       train=False,
                                       transform=transform)
    elif dataset_name == 'voc2012':
        val_dataset = datasets.VOCDetection(VOC_VAL_DIR,
                                            year='2012',
                                            image_set='val',
                                            transform=transform)
    else:
        raise ValueError('Unknown dataset_name: %s' % dataset_name)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=2)
    return val_loader
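One caveat worth adding here: `VOCDetection` returns nested-dict targets, and the default collate function fails once the number of annotated objects varies across a batch. A minimal custom collate (a sketch) that could be passed via `collate_fn=`:

import torch

def simple_collate(batch):
    # Stack the fixed-size image tensors; keep annotation dicts as a list.
    images = torch.stack([img for img, _ in batch])
    targets = [target for _, target in batch]
    return images, targets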
Example #4
    def __init__(self,
                 root,
                 year,
                 image_set,
                 do_transform=True,
                 target_transform=VOCAnnotationAnalyzer(),
                 dataset_name='VOC07_12',
                 region_propose='selective_search',
                 debug=False):
        super(VOCDectectionDataset, self).__init__()
        self.datas = datasets.VOCDetection(root,
                                           str(year),
                                           image_set,
                                           download=False)
        self.image_set = image_set
        self.do_transform = do_transform
        self.name = dataset_name
        self.target_transform = target_transform  # use for annotation
        self.longer_sides = [480, 576, 688, 864, 1200]
        self.longer_sides_small = [480, 576]
        self.debug = debug
        if region_propose not in ['selective_search', 'edge_box']:
            raise NotImplementedError(f'{region_propose} not supported')

        self.region_propose = region_propose
        self.box_mat = self.get_mat(year, image_set, region_propose)
Example #5
 def GetTrainDataset(self):
     return datasets.VOCDetection(
         cache_path,
         year=self.year,
         image_set='train',
         download=False,
         # transform=self.GetTrainTransform()
     )
Example #6
def get_pascal_voc2007_data(image_root, split='train'):
    from torchvision import datasets

    dataset = datasets.VOCDetection(image_root,
                                    year='2007',
                                    image_set=split,
                                    download=False)

    return dataset
Example #7
def VOC2012_DETECTION_DATASET(root='./data',
                              train=True,
                              transform=None,
                              target_transform=None,
                              download=False):
    return datasets.VOCDetection(root,
                                 year='2012',
                                 image_set='train' if train else 'val',
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
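A usage sketch (a guess at typical calls; the path is a placeholder):

train_set = VOC2012_DETECTION_DATASET('./data', train=True, download=True)
val_set = VOC2012_DETECTION_DATASET('./data', train=False)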
Example #8
 def __init__(self, train_image_dir, test_image_dir, transform, mode):
     self.train_image_dir = train_image_dir
     self.test_image_dir = test_image_dir
     self.transform = transform
     self.mode = mode
     self.train_dataset = dset.VOCDetection(root=self.train_image_dir,
                                            year='2012',
                                            image_set='train',
                                            download=False)
     self.test_dataset = dset.VOCDetection(root=self.test_image_dir,
                                           year='2012',
                                           image_set='val',
                                           download=False)
     print(self.train_dataset[0])
     if self.mode == 'train':
         self.num_images = len(self.train_dataset)
     else:
         self.num_images = len(self.test_dataset)
Example #9
def get_voc_generator(path, year, image_set, batch_size, device, shuffle=True):

    transform = transforms.Compose(
        [transforms.Resize((300, 300)),
         transforms.ToTensor()])

    def detection_collate_fn(sample_list):
        img_batch = []
        target_batch = []
        for (x, y) in sample_list:
            x_scale = 300 / x.size[0]
            y_scale = 300 / x.size[1]
            img_batch.append(transform(x))
            y = torch.stack([
                y[:, 0], y[:, 1] * x_scale, y[:, 2] * y_scale,
                y[:, 3] * x_scale, y[:, 4] * y_scale
            ],
                            dim=1)
            target_batch.append(y)
        img_batch = torch.stack(img_batch).to(device)
        return img_batch, target_batch

    def voc_target_transform(y):
        truths = []
        for elem in y["annotation"]["object"]:
            truth = torch.zeros((5, ), device=device, dtype=torch.float)
            truth[0] = VOC_DICT[elem["name"]]
            truth[1] = 0.5 * (int(elem["bndbox"]["xmax"]) +
                              int(elem["bndbox"]["xmin"]))
            truth[2] = 0.5 * (int(elem["bndbox"]["ymax"]) +
                              int(elem["bndbox"]["ymin"]))
            truth[3] = int(elem["bndbox"]["xmax"]) - int(
                elem["bndbox"]["xmin"])
            truth[4] = int(elem["bndbox"]["ymax"]) - int(
                elem["bndbox"]["ymin"])
            truths.append(truth)
        if truths:
            truth_array = torch.stack(truths, dim=0)
        else:
            truth_array = torch.zeros((0, 5), device=device, dtype=torch.float)

        return truth_array

    voc_data = dset.VOCDetection(root=path,
                                 year=year,
                                 image_set=image_set,
                                 target_transform=voc_target_transform)
    voc_generator = DataLoader(voc_data,
                               batch_size=batch_size,
                               collate_fn=detection_collate_fn,
                               shuffle=shuffle,
                               drop_last=True)
    return voc_generator
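`VOC_DICT` is not defined in the snippet; it presumably maps the 20 Pascal VOC class names to integer ids, e.g.:

# VOC_CLASSES: the 20 category names, as listed in the sketch under Example #2.
VOC_DICT = {name: i for i, name in enumerate(VOC_CLASSES)}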
Example #10
def get_pascal_voc2007_data(image_root, split='train'):
  """
  Use torchvision.datasets
  https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.VOCDetection
  """
  check_file = os.path.join(image_root, "extracted.txt")
  download = not os.path.exists(check_file)

  dataset = datasets.VOCDetection(image_root, year='2007', image_set=split,
                                  download=download)

  open(check_file, 'a').close()
  return dataset
Example #11
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the Pascal VOC 2012 validation set.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    - dataset size: number of samples in the test set.
    """
    normalize = transforms.Normalize(
        mean=[0.457342265910642, 0.4387686270106377, 0.4073427106250871],
        std=[0.26753769276329037, 0.2638145880487105, 0.2776826934044154],
    )

    # define transform
    transform = transforms.Compose([
        transforms.Resize(330),
        transforms.CenterCrop(300),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.VOCDetection(
        root=data_dir,
        image_set='val',
        year='2012',
        download=True,
        transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return data_loader, len(dataset)
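A hedged usage sketch; note that iterating this loader generally needs a custom collate_fn (as sketched under Example #3), since `VOCDetection` targets are nested dicts:

val_loader, num_samples = get_test_loader('./data', batch_size=8,
                                          shuffle=False)
print('validation samples:', num_samples)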
Example #12
def load_dataset(root,
                 transform,
                 *args,
                 batch_size=32,
                 shuffle=True,
                 dataset_type='folder',
                 **kwargs):
    """
    Parameters
    -----------

    dataset_type: str
        should be voc , coco, cifar, minst or folder
        if you're using voc dataset then you have to pass a param as year = 2007 or 2012
        if you're using coco dataset then you have to pass a param as type = 'detection' or 'caption'
    
    Return
    ----------
    data: Dataloader

    dataset: torchvision.dataset

    """
    if dataset_type == 'folder':
        dataset = datasets.ImageFolder(root, transform=transform)

    elif dataset_type == 'voc':
        year = kwargs.get('year', '2007')
        image_set = kwargs.get('image_set', 'train')
        dataset = datasets.VOCDetection(root,
                                        year=year,
                                        image_set=image_set,
                                        transform=transform)

    elif dataset_type == 'coco':
        assert 'type' in kwargs and 'annfile' in kwargs
        annfile = kwargs['annfile']
        coco_type = kwargs['type']
        if coco_type == 'detection':
            dataset = datasets.CocoDetection(root,
                                             annFile=annfile,
                                             transform=transform)
        elif coco_type == 'caption':
            dataset = datasets.CocoCaptions(root,
                                            annFile=annfile,
                                            transform=transform)

    else:
        raise ValueError('Unknown dataset_type: %s' % dataset_type)

    data = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return data, dataset
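A usage sketch for the voc branch (the path and transform are placeholders; iterating the returned DataLoader over dict-style VOC targets typically needs a custom collate_fn, as in Example #9):

from torchvision import transforms

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor()])
data, dataset = load_dataset('./data', transform,
                             dataset_type='voc',
                             year='2007', image_set='train')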
Example #13
    def __init__(self,
                 data_root: str,
                 datatype: str = "train",
                 transforms=None):
        self.data_root = data_root
        self.transforms = transforms
        self.splitImgPath = os.path.join(data_root,
                                         "VOCdevkit/VOC2007/ImageSets/Main",
                                         datatype + ".txt")
        with open(self.splitImgPath, "r") as splitIdx:
            # strip trailing newlines from the image ids
            self.imgNames = [line.strip() for line in splitIdx]

        # self.dataset = datasets.VOCDetection(data_root, year='2007', image_set = datatype, transform=self.transforms)
        self.dataset = datasets.VOCDetection(data_root,
                                             year='2007',
                                             image_set=datatype)
Example #14
    def get_dataset(self):
        data_augment = self._get_simclr_pipeline_transform()

        if self.name == 'stl10':
            return datasets.STL10('./data',
                                  split='train+unlabeled',
                                  download=True,
                                  transform=SimCLRDataTransform(data_augment))
        elif self.name == 'pascal-voc':
            return datasets.VOCDetection(
                './data',
                image_set='trainval',
                download=True,
                transform=SimCLRDataTransform(data_augment),
                target_transform=NoneTargetTransform())
        else:
            raise ValueError('Unsupported dataset name: %s' % self.name)
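`SimCLRDataTransform` and `NoneTargetTransform` are project helpers not shown here. Since SimCLR is self-supervised and never reads the labels, `NoneTargetTransform` presumably discards the VOC annotation dict; a guess at its shape:

class NoneTargetTransform:
    """Replace the VOC annotation dict with a dummy label."""

    def __call__(self, target):
        return 0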
Example #15
 def __init__(self, root, year, image_set,
              target_transform=VOCAnnotationAnalyzer(),
              dataset_name='VOC07_12',
              region_propose='selective_search',
              use_corloc=False,
              debug=False,
              small_box=True,
              over_box=True):
     super(VOCDectectionDataset, self).__init__()
     self.datas = datasets.VOCDetection(root, str(year), image_set, download=False)
     self.image_set = image_set
     self.name = dataset_name
     self.target_transform = target_transform # use for annotation
     self.debug = debug
     self.region_propose = region_propose
     self.box_mat = self.get_mat(year, image_set, region_propose)
     self.use_corloc = use_corloc
     self.small_box = small_box
     self.over_box = over_box
Example #16
def load_dataset(root,
                 transform,
                 batch_size=32,
                 shuffle=True,
                 dataset_type='folder',
                 *args,
                 **kwargs):
    """
    param
    dataset_type: str
        should be voc , coco, cifar, minst or folder
    
    """
    if dataset_type == 'folder':
        dataset = datasets.ImageFolder(root, transform=transform)

    elif dataset_type == 'voc':
        year = kwargs['year']
        image_set = kwargs['image_set']
        dataset = datasets.VOCDetection(root,
                                        year=year,
                                        image_set=image_set,
                                        transform=transform)
    elif dataset_type == 'coco':
        annfile = kwargs['annfile']
        coco_type = kwargs['type']
        if coco_type == 'detect':
            dataset = datasets.CocoDetection(root,
                                             annFile=annfile,
                                             transform=transform)
        elif coco_type == 'caption':
            dataset = datasets.CocoCaptions(root,
                                            annFile=annfile,
                                            transform=transform)
    else:
        raise ValueError('Unknown dataset_type: %s' % dataset_type)

    data = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # Note: .classes and .class_to_idx exist only on folder-style datasets;
    # the voc and coco branches will raise AttributeError here.
    return data, dataset.classes, dataset.class_to_idx
Example #17
def _load_dataset(root='',
                  name='cifar10',
                  train=True,
                  download=True,
                  transform=None,
                  from_folder=''):
    """ Initialize dataset from torchvision or from folder

    Args:
        root: (str) Directory where dataset is stored
        name: (str) Name of the dataset (e.g. cifar10, cifar100)
        train: (bool) Use the training set
        download: (bool) Download the dataset
        transform: (torchvision.transforms.Compose) image transformations
        from_folder: (str) Path to directory holding the images to load.

    Returns:
        A torchvision dataset

    Raises:
        ValueError: If the specified dataset doesn't exist

    """

    if from_folder and os.path.exists(from_folder):
        # load data from directory
        dataset = _load_dataset_from_folder(from_folder,
                                            transform)

    elif name.lower() == 'cifar10' and root:
        # load cifar10
        dataset = datasets.CIFAR10(root,
                                   train=train,
                                   download=download,
                                   transform=transform)

    elif name.lower() == 'cifar100' and root:
        # load cifar100
        dataset = datasets.CIFAR100(root,
                                    train=train,
                                    download=download,
                                    transform=transform)

    elif name.lower() == 'cityscapes' and root:
        # load cityscapes
        root = os.path.join(root, 'cityscapes/')
        split = 'train' if train else 'val'
        dataset = datasets.Cityscapes(root,
                                      split=split,
                                      transform=transform)

    elif name.lower() == 'stl10' and root:
        # load stl10
        split = 'train' if train else 'test'
        dataset = datasets.STL10(root,
                                 split=split,
                                 download=download,
                                 transform=transform)

    elif name.lower() == 'voc07-seg' and root:
        # load pascal voc 07 segmentation dataset
        image_set = 'train' if train else 'val'
        dataset = datasets.VOCSegmentation(root,
                                           year='2007',
                                           image_set=image_set,
                                           download=download,
                                           transform=transform)

    elif name.lower() == 'voc12-seg' and root:
        # load pascal voc 12 segmentation dataset
        image_set = 'train' if train else 'val'
        dataset = datasets.VOCSegmentation(root,
                                           year='2012',
                                           image_set=image_set,
                                           download=download,
                                           transform=transform)

    elif name.lower() == 'voc07-det' and root:
        # load pascal voc 07 object detection dataset
        image_set = 'train' if train else 'val'
        dataset = datasets.VOCDetection(root,
                                        year='2007',
                                        image_set=image_set,
                                        download=download,
                                        transform=transform)

    elif name.lower() == 'voc12-det' and root:
        # load pascal voc 12 object detection dataset
        image_set = 'train' if train else 'val'
        dataset = datasets.VOCDetection(root,
                                        year='2012',
                                        image_set=image_set,
                                        download=download,
                                        transform=transform)

    else:
        raise ValueError(
            'The specified dataset (%s) or datafolder (%s) does not exist '
            % (name, from_folder))

    return dataset
Example #18
def load_datasets(min_shape=(600, 600),
                  max_shape=(1000, 1000),
                  sub_sample=16,
                  ceil_mode=False,
                  pad_to_max=True,
                  stretch_to_max=False):
    mean_val = [0.485, 0.456, 0.406]
    std_val = [0.229, 0.224, 0.225]
    train_transforms_list = [
        RandomHorizontalFlip(0.5),
        DynamicResize(min_shape, max_shape, stretch_to_max=stretch_to_max),
        StandardTransform(tvt.ToTensor()),
        StandardTransform(tvt.Normalize(mean_val, std_val)),
    ]
    val_transforms_list = [
        DynamicResize(min_shape, max_shape, stretch_to_max=stretch_to_max),
        StandardTransform(tvt.ToTensor()),
        StandardTransform(tvt.Normalize(mean_val, std_val)),
    ]
    if pad_to_max:
        train_transforms_list.append(StandardTransform(PadToShape(max_shape)))
        val_transforms_list.append(StandardTransform(PadToShape(max_shape)))
    train_transforms_list.append(
        CreateRPNLabels(sub_sample=sub_sample, ceil_mode=ceil_mode))
    val_transforms_list.append(
        CreateRPNLabels(sub_sample=sub_sample, ceil_mode=ceil_mode))

    if args.dataset == 'coco':
        print('Loading MSCOCO Detection dataset')
        train_transforms_list.insert(0, FormatCOCOLabels())
        val_transforms_list.insert(0, FormatCOCOLabels())

        train_transforms = ComposeTransforms(train_transforms_list)
        val_transforms = ComposeTransforms(val_transforms_list)

        train_dataset = CocoDetectionWithImgId(args.coco_root,
                                               image_set='train',
                                               download=True,
                                               transforms=train_transforms)
        val_dataset = CocoDetectionWithImgId(args.coco_root,
                                             image_set='val',
                                             download=True,
                                             transforms=val_transforms)
    elif args.dataset == 'voc':
        print('Loading Pascal VOC 2007 Detection dataset')
        train_transforms_list.insert(0, FormatVOCLabels(use_difficult=False))
        val_transforms_list.insert(0, FormatVOCLabels(use_difficult=True))

        train_transforms = ComposeTransforms(train_transforms_list)
        val_transforms = ComposeTransforms(val_transforms_list)

        download = not os.path.exists(
            os.path.join(args.voc_root, 'VOCdevkit/VOC2007'))
        train_dataset = torch.utils.data.ConcatDataset([
            datasets.VOCDetection(args.voc_root,
                                  year='2007',
                                  download=download,
                                  image_set='train',
                                  transforms=train_transforms),
            datasets.VOCDetection(args.voc_root,
                                  year='2007',
                                  download=download,
                                  image_set='val',
                                  transforms=train_transforms)
        ])

        val_dataset = datasets.VOCDetection(args.voc_root,
                                            year='2007',
                                            download=download,
                                            image_set='test',
                                            transforms=val_transforms)
    else:
        raise ValueError

    train_sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=train_sampler,
        num_workers=args.num_workers,
        drop_last=True,
        collate_fn=faster_rcnn_collate_fn,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             drop_last=False,
                                             collate_fn=faster_rcnn_collate_fn,
                                             pin_memory=True)

    return train_loader, val_loader
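`faster_rcnn_collate_fn` is not shown. Because each image carries its own variable-length set of boxes and RPN labels, it presumably batches by keeping parallel lists rather than stacking targets; a sketch under that assumption:

def faster_rcnn_collate_fn(batch):
    # Keep images and per-image targets as parallel lists so each
    # sample retains its own variable-length annotations.
    images, targets = zip(*batch)
    return list(images), list(targets)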
Example #19
def finetune(data_dir,
             checkpoint_path,
             arch='vgg16',
             dataset='voc_2007',
             ann_dir=None,
             train_split='train',
             val_split='val',
             input_size=224,
             optimizer_name='SGD',
             lr=1e-2,
             epochs=100,
             batch_size=64,
             workers=8,
             resume_checkpoint_path=None,
             should_validate=False):
    """
    Finetune the last layer of a CNN pretrained on ImageNet.

    Args:
        data_dir: String, path to root directory containing image files.
        checkpoint_path: String, path to save checkpoint.
        arch: String, name of torchvision.models architecture.
        dataset: String, name of dataset.
        ann_dir: String, path to root directory containing annotation files
            (used for COCO).
        train_split: String, name of split to use for training.
        val_split: String, name of split to use for validation.
        input_size: Integer, length of the side of the input image.
        optimizer_name: String, name of torch.optim.Optimizer to use.
        lr: Float, learning rate to use.
        epochs: Integer, number of epochs to train for.
        batch_size: Integer, batch size to use.
        workers: Integer, number of workers to use for loading data.
        resume_checkpoint_path: String, checkpoint to resume training from.
        should_validate: Boolean, if True, validate model (no training).
    """
    # Load model, replacing the last layer.
    model = get_finetune_model(arch=arch, dataset=dataset)

    # Prepare data augmentation.
    # Use caffe normalization for GoogLeNet.
    if arch == 'googlenet':
        normalize = GoogLeNetNormalize()
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Prepare data loaders.
    if 'voc' in dataset:
        year = dataset.split('_')[-1]

        target_transform = transforms.Compose([
            FromVOCToOneHotEncoding(),
            SimpleToTensor(),
        ])

        train_dset = datasets.VOCDetection(data_dir,
                                           year=year,
                                           image_set=train_split,
                                           transform=train_transform,
                                           target_transform=target_transform)
        val_dset = datasets.VOCDetection(data_dir,
                                         year=year,
                                         image_set=val_split,
                                         transform=val_transform,
                                         target_transform=target_transform)
    elif 'coco' in dataset:
        train_ann_path = os.path.join(ann_dir,
                                      'instances_%s.json' % train_split)
        val_ann_path = os.path.join(ann_dir, 'instances_%s.json' % val_split)

        target_transform = transforms.Compose([
            FromCocoToOneHotEncoding(),
            SimpleToTensor(),
        ])
        train_dset = datasets.CocoDetection(os.path.join(
            data_dir, train_split),
                                            train_ann_path,
                                            transform=train_transform,
                                            target_transform=target_transform)
        val_dset = datasets.CocoDetection(os.path.join(data_dir, val_split),
                                          val_ann_path,
                                          transform=val_transform,
                                          target_transform=target_transform)
    else:
        raise ValueError('Unsupported dataset: %s' % dataset)

    train_loader = torch.utils.data.DataLoader(train_dset,
                                               batch_size=batch_size,
                                               num_workers=workers,
                                               shuffle=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dset,
                                             batch_size=batch_size,
                                             num_workers=workers,
                                             shuffle=False,
                                             pin_memory=True)

    # Define loss criterion.
    criterion = nn.BCEWithLogitsLoss()

    # Move model to GPU or CPU.
    device = get_device()
    model = model.to(device)
    criterion = criterion.to(device)

    # Prepare optimizer.
    optimizer = torch.optim.__dict__[optimizer_name](model.parameters(), lr=lr)

    # Restore previous checkpoint, if provided.
    if resume_checkpoint_path is not None:
        checkpoint = torch.load(resume_checkpoint_path)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        start_epoch = 0

    # Validate model.
    if should_validate:
        validate(val_loader, model, criterion, device)
        return

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    best_val_loss = np.inf
    for epoch in range(start_epoch, epochs):
        curr_lr = optimizer.param_groups[0]['lr']
        if curr_lr != lr:
            print('Breaking early at epoch %d because training plateaued.' %
                  epoch)
            break
        train(train_loader, model, criterion, optimizer, epoch, device)
        val_loss, prec, rec = validate(val_loader, model, criterion, device)
        scheduler.step(val_loss)

        # Save checkpoint.
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        res = {
            'epoch': epoch + 1,
            'arch': arch,
            'dataset': dataset,
            'data_dir': data_dir,
            'input_size': input_size,
            'lr': lr,
            'prec': prec,
            'rec': rec,
            'epochs': epochs,
            'batch_size': batch_size,
            'state_dict': model.state_dict(),
            'best_val_loss': best_val_loss,
            'optimizer': optimizer.state_dict()
        }
        save_checkpoint(res, checkpoint_path, is_best=is_best)
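`FromVOCToOneHotEncoding` and `SimpleToTensor` are project helpers not shown here; for the `BCEWithLogitsLoss` criterion they presumably produce a 20-dim multi-hot label. A sketch, reusing the `VOC_CLASSES` list from the sketch under Example #2:

import torch

class FromVOCToOneHotEncoding:
    def __call__(self, target):
        # Multi-hot vector over the 20 VOC categories.
        labels = [0.0] * len(VOC_CLASSES)
        for obj in target['annotation']['object']:
            labels[VOC_CLASSES.index(obj['name'])] = 1.0
        return labels

class SimpleToTensor:
    def __call__(self, x):
        return torch.as_tensor(x)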
Example #20
def main(seed_size=None,
         train_batch_size=None,
         test_batch_size=None,
         num_of_data_to_add=None,
         sampling_parameter=None):
    directory = './data'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device: " + str(device))
    if not os.path.exists(directory):
        os.mkdir(directory)
    mean = [0.457342265910642, 0.4387686270106377, 0.4073427106250871]
    std = [0.26753769276329037, 0.2638145880487105, 0.2776826934044154]
    transformations = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.RandomChoice([
            transforms.ColorJitter(brightness=(0.80, 1.20)),
            transforms.RandomGrayscale(p=0.25)
        ]),
        transforms.RandomHorizontalFlip(p=0.25),
        transforms.RandomRotation(25),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    transformations_test = transforms.Compose([
        transforms.Resize(330),
        transforms.FiveCrop(300),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.Normalize(mean=mean, std=std)(crop)
             for crop in crops])),
    ])
    train_set = datasets.VOCDetection(directory,
                                      year="2007",
                                      image_set="train",
                                      transform=transformations,
                                      target_transform=create_label,
                                      download=True)
    test_set = datasets.VOCDetection(directory,
                                     year="2007",
                                     image_set="val",
                                     transform=transformations_test,
                                     target_transform=create_label,
                                     download=True)
    if not seed_size:
        seed_size = int(
            input("Enter desired number of data in initial seed: "))
    if not train_batch_size:
        train_batch_size = int(input("Enter size of batch for training: "))
    if not test_batch_size:
        test_batch_size = int(input("Enter size of batch for validation: "))
    if not num_of_data_to_add:
        num_of_data_to_add = int(
            input(
                "Enter number of new data to add to the train set in each epoch: "
            ))
    while True:
        if not sampling_parameter:
            method = input(
                "Choose sampling method for unlabeled data from the available methods:\n"
                "1 - Least Confidence Sampling\n"
                "2 - Margin Sampling\n"
                "3 - Entropy Sampling\n"
                "Enter one of the numbers: ")
        else:
            method = sampling_parameter
        try:
            number = int(method)
            if number == 1:
                sampling_method = least_confidence_sampling
                break
            elif number == 2:
                sampling_method = margin_sampling
                break
            elif number == 3:
                sampling_method = entropy_sampling
                break
            else:
                raise ValueError
        except ValueError:
            print("Wrong argument entered.")
    labeled_data_indexes = list()
    while len(labeled_data_indexes) < seed_size:
        index = randint(0, len(train_set) - 1)
        if index not in labeled_data_indexes:
            labeled_data_indexes.append(index)
    labeled_dataset = CustomLabeledDataset(train_set, labeled_data_indexes)
    unlabeled_data_indexes = list()
    for index in range(0, len(train_set)):
        if index not in labeled_data_indexes:
            unlabeled_data_indexes.append(index)
    unlabeled_dataset = CustomUnlabeledDataset(train_set,
                                               unlabeled_data_indexes)
    labeled_loader = data.DataLoader(labeled_dataset,
                                     batch_size=train_batch_size,
                                     shuffle=True,
                                     pin_memory=True,
                                     num_workers=multiprocessing.cpu_count())
    unlabeled_loader = data.DataLoader(unlabeled_dataset,
                                       batch_size=train_batch_size,
                                       shuffle=True,
                                       num_workers=multiprocessing.cpu_count())
    test_loader = data.DataLoader(dataset=test_set,
                                  batch_size=test_batch_size,
                                  pin_memory=True)

    model = models.resnet18(pretrained=True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(model.fc.in_features, 20)
    model.to(device)
    loss_func = nn.BCEWithLogitsLoss(reduction='sum')
    optimizer = optim.SGD([{
        'params': list(model.parameters())[:-1],
        'lr': 1e-5,
        'momentum': 0.9
    }, {
        'params': list(model.parameters())[-1],
        'lr': 5e-3,
        'momentum': 0.9
    }])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     12,
                                                     eta_min=0,
                                                     last_epoch=-1)

    m = nn.Sigmoid()
    sizes_of_data = list()
    average_precisions = list()

    print("Started training on initial seed.")
    for epoch in range(15):
        print("Started training in epoch %d." % (epoch + 1))
        for data_labeled in labeled_loader:
            inputs, labels = data_labeled[0].to(
                device), data_labeled[1].float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()
        print("Finished training in epoch %d." % (epoch + 1))
    print("Finished training on initial seed.")
    print("Evaluating model on train set.")
    average_precision = 0.0
    # average_loss = 0.0
    with torch.no_grad():
        for test_data in test_loader:
            inputs, labels = test_data[0].to(device), test_data[1].float().to(
                device)
            bs, ncrops, c, h, w = inputs.size()
            outputs = model(inputs.view(-1, c, h, w))
            outputs = outputs.view(bs, ncrops, -1).mean(1)
            # average_loss += loss_func(outputs, labels).item()
            average_precision += get_average_precision(
                labels.detach().cpu().numpy(),
                m(outputs).detach().cpu().numpy())
    average_precision = round(average_precision / len(test_set), 2)
    # average_loss = average_loss / len(test_set)
    sizes_of_data.append(len(labeled_dataset))
    average_precisions.append(average_precision)
    print("Evaluation complete.")
    # print("Average precision of neural network on %d samples is: %.2f%%" % (len(labeled_dataset),
    #                                                                         average_precision))
    # print("Average loss of neural network on %d samples is: %f" % (len(labeled_dataset), average_loss))

    for iteration in range(10):
        print("Started training in iteration %d." % (iteration + 1))
        if len(unlabeled_dataset) > 0:
            labeled_dataset.list_of_indexes += sampling_method(
                num_of_data_to_add, unlabeled_loader, model, device)
        for epoch in range(15):
            print("Started training in epoch %d." % (epoch + 1))
            for i, data_labeled in enumerate(labeled_loader, 1):
                inputs, labels = data_labeled[0].to(
                    device), data_labeled[1].float().to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()
                # print("Epoch: %d, batch: %d, loss: %f" % (epoch + 1, i, loss.item()))
            scheduler.step()
            print("Finished training in epoch %d." % (epoch + 1))
        print("Finished training in iteration %d." % (iteration + 1))
        print("Evaluating model on train set.")
        average_precision = 0.0
        # average_loss = 0.0
        with torch.no_grad():
            for test_data in test_loader:
                inputs, labels = test_data[0].to(
                    device), test_data[1].float().to(device)
                bs, ncrops, c, h, w = inputs.size()
                outputs = model(inputs.view(-1, c, h, w))
                outputs = outputs.view(bs, ncrops, -1).mean(1)
                # average_loss += loss_func(outputs, labels).item()
                average_precision += get_average_precision(
                    labels.detach().cpu().numpy(),
                    m(outputs).detach().cpu().numpy())
        average_precision = round(average_precision / len(test_set), 2)
        # average_loss = average_loss / len(test_set)
        sizes_of_data.append(len(labeled_dataset))
        average_precisions.append(average_precision)
        print("Evaluation complete.")
        # print("Average precision of neural network on %d samples is: %.2f%%" % (len(labeled_dataset),
        #                                                                         average_precision))
        # print("Average loss of neural network on %d samples is: %f" % (len(labeled_dataset), average_loss))

    return sizes_of_data, average_precisions
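`create_label` (the `target_transform` shared with Example #22) and `get_average_precision` are not shown. Since `model.fc` outputs 20 logits trained with `BCEWithLogitsLoss`, `create_label` presumably builds a multi-hot vector; a sketch, again assuming the `VOC_CLASSES` list from Example #2:

import torch

def create_label(target):
    # Multi-hot label over the 20 VOC categories.
    label = torch.zeros(len(VOC_CLASSES))
    for obj in target['annotation']['object']:
        label[VOC_CLASSES.index(obj['name'])] = 1.0
    return label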
Example #21
denormalize = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.],
                         std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                         std=[1., 1., 1.]),
])

# choose the training and test datasets
#train_data = datasets.VOCDetection('VOCDetectionData', image_set = 'train',
#                              download=True, transform=augmented_transform)
test_data = datasets.VOCDetection('VOCDetectionData', image_set='val',
                                  download=True, transform=base_transform)

#test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

print(len(test_data))

for i in range(len(test_data)):
    img, target = test_data[i]

    image = denormalize(img)
    image = image.numpy()
    image = image.transpose(1, 2, 0)
    image = image.clip(0, 1)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
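`base_transform` is not defined in this snippet. Since the loop denormalizes with ImageNet statistics and calls `.numpy()` on the image, it presumably ends in `ToTensor` plus the standard ImageNet normalization; the resize size is a guess:

base_transform = transforms.Compose([
    transforms.Resize((300, 300)),  # size assumed
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])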
Example #22
def main(num_of_data=None, batch_size_train=None, batch_size_test=None):
    directory = './data'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device: " + str(device))
    if not os.path.exists(directory):
        os.mkdir(directory)
    mean = [0.457342265910642, 0.4387686270106377, 0.4073427106250871]
    std = [0.26753769276329037, 0.2638145880487105, 0.2776826934044154]
    transformations = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.RandomChoice([
            transforms.ColorJitter(brightness=(0.80, 1.20)),
            transforms.RandomGrayscale(p=0.25)
        ]),
        transforms.RandomHorizontalFlip(p=0.25),
        transforms.RandomRotation(25),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    transformations_test = transforms.Compose([
        transforms.Resize(330),
        transforms.FiveCrop(300),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.Normalize(mean=mean, std=std)(crop)
             for crop in crops])),
    ])
    train_set = datasets.VOCDetection(directory,
                                      year="2007",
                                      image_set="train",
                                      transform=transformations,
                                      target_transform=create_label,
                                      download=True)
    test_set = datasets.VOCDetection(directory,
                                     year="2007",
                                     image_set="val",
                                     transform=transformations_test,
                                     target_transform=create_label,
                                     download=True)
    if not num_of_data:
        num_of_data = int(input("Enter number of data to train on: "))
    if num_of_data == len(train_set):
        train_dataset = train_set
    else:
        indexes = list()
        while len(indexes) < num_of_data:
            index = randint(0, len(train_set) - 1)
            if index not in indexes:
                indexes.append(index)
        train_dataset = CustomLabeledDataset(train_set, indexes)
    if not batch_size_train:
        batch_size_train = int(
            input("Enter desired batch size for training: "))
    if not batch_size_test:
        batch_size_test = int(
            input("Enter desired batch size for validation: "))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        pin_memory=True,
        num_workers=multiprocessing.cpu_count())
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch_size_test,
                                              pin_memory=True)

    model = models.resnet18(pretrained=True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(model.fc.in_features, 20)
    model.to(device)
    loss_func = nn.BCEWithLogitsLoss(reduction='sum')
    optimizer = optim.SGD([{
        'params': list(model.parameters())[:-1],
        'lr': 1e-5,
        'momentum': 0.9
    }, {
        'params': list(model.parameters())[-1],
        'lr': 5e-3,
        'momentum': 0.9
    }])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                     12,
                                                     eta_min=0,
                                                     last_epoch=-1)

    print("Started training.")

    for epoch in range(15):
        for i, data in enumerate(train_loader, 1):
            inputs, labels = data[0].to(device), data[1].float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            loss.backward()
            optimizer.step()
            # print("Epoch: %d, batch: %d, loss: %f" % (epoch + 1, i, loss.item()))
        scheduler.step()

    print("Finished Training.")
    print("Evaluating model on train set.")

    average_precision = 0.0
    average_loss = 0.0
    m = nn.Sigmoid()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device), data[1].float().to(device)
            bs, ncrops, c, h, w = inputs.size()
            outputs = model(inputs.view(-1, c, h, w))
            outputs = outputs.view(bs, ncrops, -1).mean(1)
            average_loss += loss_func(outputs, labels).item()
            average_precision += get_average_precision(
                labels.detach().cpu().numpy(),
                m(outputs).detach().cpu().numpy())

    # print("Average precision of neural network is: %.2f%%" % (100 * (average_precision / len(test_set))))
    # print("Average loss of neural network is: %f" % (average_loss / len(test_set)))
    return len(train_dataset), round(average_precision / len(test_set), 2)
Example #23
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the Pascal VOC 2012 dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(
        mean=[0.457342265910642, 0.4387686270106377, 0.4073427106250871],
        std=[0.26753769276329037, 0.2638145880487105, 0.2776826934044154],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.RandomChoice([
                transforms.ColorJitter(brightness=(0.80, 1.20)),
                transforms.RandomGrayscale(p=0.25)
            ]),
            transforms.RandomHorizontalFlip(p=0.25),
            transforms.RandomRotation(25),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize(330),
            transforms.CenterCrop(300),
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset
    train_dataset = datasets.VOCDetection(
        root=data_dir,
        image_set='train',
        year='2012',
        download=True,
        transform=train_transform,
    )

    valid_dataset = datasets.VOCDetection(
        root=data_dir,
        image_set='trainval',
        year='2012',
        download=True,
        transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # use the train/valid split samplers created above
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=9,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        data_iter = iter(sample_loader)
        images, labels = next(data_iter)
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)
Example #24
def main_worker(gpu, ngpus_per_node, logger, args):
    global best_acc1, metrics

    assert ".pth.tar" in args.checkpoint

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    model = get_model(args.arch,
                      dataset=args.dataset,
                      pretrained=args.pretrained)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = get_criterion(args.dataset).cuda(args.gpu)

    if args.finetune_last_layer:
        last_name, last_module = get_last_layer(model)
        optimizer = torch.optim.SGD(
            [
                {
                    "params": [
                        param for name, param in model.named_parameters()
                        if name != last_name
                    ],
                    "lr":
                    args.lr * args.finetune_decay_factor,
                },
                {
                    "params": last_module.parameters()
                },
            ],
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # resume checkpoint if available
    if os.path.isfile(args.checkpoint):
        print("=> loading checkpoint '{}'".format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint)
        start_iter = checkpoint['iter']
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        metrics = checkpoint['metrics']
        if args.gpu is not None:
            # best_acc1 may be from a checkpoint from a different GPU
            best_acc1 = best_acc1.to(args.gpu)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.checkpoint, checkpoint['epoch']))
    else:
        start_iter = 0
        args.start_epoch = 0
        # No checkpoint to restore from, so initialize the tracked state
        # here; best_acc1 and metrics are otherwise undefined below.
        best_acc1 = 0.0
        metrics = {
            key: {}
            for key in ('lr', 'train/acc1', 'train/acc5', 'train/loss',
                        'val/acc1', 'val/acc5', 'val/loss')
        }
        print("=> no checkpoint found at '{}'; creating a new one...".format(
            args.checkpoint))
        save_checkpoint(
            {
                'epoch': args.start_epoch,
                'arch': args.arch,
                'iter': start_iter,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'metrics': metrics,
                'args': args,
            }, False, args.checkpoint)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if args.dataset == "imagenet":
        traindir = os.path.join(args.data, args.train_split)
        valdir = os.path.join(args.data, args.val_split)

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
    elif args.dataset == "pascal":
        train_dataset = datasets.VOCDetection(
            args.data,
            year=args.year,
            image_set=args.train_split,
            download=args.download,
            transform=transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            target_transform=custom_transforms.VOCToClassVector())

        val_dataset = datasets.VOCDetection(
            args.data,
            year=args.year,
            image_set=args.val_split,
            download=False,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]),
            target_transform=custom_transforms.VOCToClassVector())
    else:
        raise NotImplementedError("{} dataset not supported; only {}.".format(
            args.dataset, ' | '.join(dataset_names)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args.start_epoch, logger, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # save current learning rate
        curr_lr = optimizer.param_groups[0]['lr']
        metrics['lr'][epoch] = curr_lr
        if logger is not None:
            logger.add_scalar('train/lr', curr_lr, epoch)

        # initialize start iteration
        start_iter_for_epoch = start_iter if epoch == args.start_epoch else 0

        # train for one epoch
        train_acc1, train_acc5, train_loss = train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            logger,
            args,
            start_iter=start_iter_for_epoch)

        # evaluate on validation set
        val_acc1, val_acc5, val_loss = validate(val_loader, model, criterion,
                                                epoch, logger, args)

        metrics['train/acc1'][epoch] = train_acc1
        metrics['train/acc5'][epoch] = train_acc5
        metrics['train/loss'][epoch] = train_loss
        metrics['val/acc1'][epoch] = val_acc1
        metrics['val/acc5'][epoch] = val_acc5
        metrics['val/loss'][epoch] = val_loss

        # remember best acc@1 and save checkpoint
        is_best = val_acc1 > best_acc1
        best_acc1 = max(val_acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'iter': 0,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                    'metrics': metrics,
                    'args': args,
                }, is_best, args.checkpoint)
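The helpers referenced above (get_model, train, validate, adjust_learning_rate, save_checkpoint) are defined elsewhere in this repository. As a reference, here is a minimal sketch of the step-decay schedule that adjust_learning_rate conventionally implements in scripts derived from the PyTorch ImageNet example; this is an assumption about its behavior, not a copy of the actual helper:

def adjust_learning_rate(optimizer, epoch, args):
    # Assumed step decay: scale the base lr by 0.1 every 30 epochs.
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr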
Example No. 25
0
def learn_masks_for_pointing(data_dir,
                             checkpoint_path,
                             arch='vgg16',
                             dataset='voc_2007',
                             ann_dir=None,
                             split='test',
                             input_size=224):
    """
    Learn explanatory masks for the pointing game.

    Args:
        data_dir: String, path to root directory for dataset.
        checkpoint_path: String, path to checkpoint.
        arch: String, name of torchvision.models architecture.
        dataset: String, name of dataset.
        ann_dir: String, path to root directory containing annotation files
            (used for COCO).
        split: String, name of split.
        input_size: Integer, length of the side of the input image.
    """
    # Load fine-tuned model and convert it to be fully convolutional.
    model = get_finetune_model(arch=arch,
                               dataset=dataset,
                               checkpoint_path=checkpoint_path,
                               convert_to_fully_convolutional=True)
    device = get_device()
    model = model.to(device)

    # Prepare data augmentation.
    assert (isinstance(input_size, int))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        normalize,
    ])

    # Prepare data loaders.
    if 'voc' in dataset:
        year = dataset.split('_')[-1]

        target_transform = transforms.Compose([
            FromVOCToOneHotEncoding(),
            SimpleToTensor(),
        ])

        dset = datasets.VOCDetection(data_dir,
                                     year=year,
                                     image_set=split,
                                     transform=transform,
                                     target_transform=target_transform)
    elif 'coco' in dataset:
        ann_path = os.path.join(ann_dir, 'instances_%s.json' % split)
        target_transform = transforms.Compose([
            FromCocoToOneHotEncoding(),
            SimpleToTensor(),
        ])
        dset = datasets.CocoDetection(os.path.join(data_dir, split),
                                      ann_path,
                                      transform=transform,
                                      target_transform=target_transform)
    else:
        raise NotImplementedError(f'dataset not supported: {dataset}')

    loader = torch.utils.data.DataLoader(dset,
                                         batch_size=1,
                                         num_workers=1,
                                         shuffle=False,
                                         pin_memory=True)

    for i, (x, y) in enumerate(loader):
        # Move data to device.
        x = x.to(device)
        y = y.to(device)

        # Compute forward pass.
        pred_y = model(x)

        # Verify shape.
        assert (y.shape[0] == 1)
        assert (len(y.shape) == 2)
        assert (len(pred_y.shape) == 4)

        # Get present classes in image.
        class_idx = np.where(y[0].cpu().data.numpy())[0]

        # Compute a mask for each present class in the image.
        for c in class_idx:
            # Match fully convolutional output shape.
            class_y = torch.zeros_like(pred_y)
            class_y[0, c, :, :] = 1

            # Gradient signal.
            grad_signal = pred_y * class_y

            # TODO: Compute mask.
            pass
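The mask computation above is intentionally left as a TODO. As a purely illustrative sketch of one simple possibility (gradient saliency, not necessarily the method this function is meant to implement), a coarse mask can be derived from the class-selective gradient signal; gradient_saliency_mask is a hypothetical helper:

import torch

def gradient_saliency_mask(model, x, class_y):
    # x: (1, 3, H, W) image; class_y: tensor shaped like the model output
    # with ones at the target class (cf. class_y in the loop above).
    x = x.clone().detach().requires_grad_(True)
    model.zero_grad()
    pred_y = model(x)
    # Backpropagate only the target-class signal (cf. grad_signal above).
    pred_y.backward(class_y)
    # Channel-wise max of |d pred / d x|, normalized to [0, 1].
    saliency, _ = torch.max(x.grad.abs(), dim=1, keepdim=True)
    return (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)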
Example No. 26
0
import torch
import torchvision.datasets as datasets
data = datasets.VOCDetection('datasets/',
                             year='2007',
                             image_set='val',
                             download=True)
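A quick way to inspect what VOCDetection returns: each sample is a (PIL image, nested dict) pair parsed from the VOC XML annotation, and the 'object' entry is a single dict when the image contains one object and a list otherwise:

img, target = data[0]
print(img.size)                        # PIL image size (W, H)
objects = target['annotation']['object']
if not isinstance(objects, list):      # a single object comes back as a dict
    objects = [objects]
for obj in objects:
    print(obj['name'], obj['bndbox'])  # class name and box coords (strings)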
Example No. 27
0
def get_voc_data(dataset='seg',
                 batch_size=64,
                 test_batch_size=1,
                 year='2008',
                 root='../data/VOCdevkit',
                 download=False):
    shuffle = False
    kwargs = {}

    transformations = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
    ])

    def proc_img(img):
        # Reinterpret the uint8 palette image as int8 in place: pixel values
        # are preserved except the VOC 'void' label 255, which wraps to -1
        # (a conventional ignore index).
        arr = np.array(img)
        arr.dtype = np.int8
        return arr

    trgt_transformations = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.Lambda(proc_img),
        transforms.ToTensor()
    ])

    dataset_train = None
    dataset_test = None

    if dataset == 'seg':
        dataset_train = datasets.VOCSegmentation(
            root,
            year=year,
            image_set="train",
            download=download,
            transform=transformations,
            target_transform=trgt_transformations)
        dataset_test = datasets.VOCSegmentation(
            root,
            year=year,
            image_set="val",
            download=download,
            transform=transformations,
            target_transform=trgt_transformations)

    elif dataset == 'det':
        # VOCDetection targets are parsed XML annotation dicts, not images,
        # so the image-style target transforms above cannot be applied here.
        dataset_train = datasets.VOCDetection(
            root,
            year=year,
            image_set="train",
            download=download,
            transform=transformations,
            target_transform=None)
        dataset_test = datasets.VOCDetection(
            root,
            year=year,
            image_set="val",
            download=download,
            transform=transformations,
            target_transform=None)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=batch_size,
                                               shuffle=shuffle,
                                               num_workers=4,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=test_batch_size,
                                              num_workers=4,
                                              shuffle=shuffle,
                                              **kwargs)

    return train_loader, test_loader
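A minimal usage sketch for the loader factory above, assuming the VOC 2008 segmentation data is already extracted under the default root (otherwise pass download=True):

train_loader, test_loader = get_voc_data(dataset='seg', batch_size=8)
images, masks = next(iter(train_loader))
print(images.shape)  # e.g. torch.Size([8, 3, 224, 224])
print(masks.shape)   # int8 label maps; the void label 255 is mapped to -1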
Example No. 28
0
import torch
import torch.nn as nn
from torchvision import datasets, transforms

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    # Note: these are single-channel MNIST statistics broadcast over all
    # three RGB channels; ImageNet statistics would be more typical here.
    transforms.Normalize((0.1307, ), (0.3081, ))
])

dataset = datasets.VOCDetection("/datasets",
                                year='2012',
                                image_set='trainval',
                                download=True,
                                transform=transform)

dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=1)

X, y = next(iter(dataloader))
print(X.shape)
print(y)

for X, y in dataloader:
    print(X.shape)
    # 'object' is a list when the image contains several objects and a
    # single dict otherwise. The original snippet is truncated here; a
    # plausible continuation counts the annotated objects:
    if isinstance(y['annotation']['object'], list):
        print(len(y['annotation']['object']))
    else:
        print(1)
Example No. 29
0
def pointing_game(data_dir,
                  checkpoint_path,
                  out_path=None,
                  save_dir=None,
                  load_from_save_dir=False,
                  arch='vgg16',
                  converted_caffe=False,
                  dataset='voc_2007',
                  ann_dir=None,
                  split='test',
                  metric='pointing',
                  input_size=224,
                  vis_method='gradient',
                  final_gap_layer=False,
                  debug=False,
                  print_iter=1,
                  save_iter=25,
                  start_index=-1,
                  end_index=-1,
                  layer_name=None,
                  eps=1e-6,
                  num_masks=4000,
                  s=7,
                  p1=0.5,
                  rise_filter_path=None,
                  gpu_batch=100):
    """
    Play the pointing game using a finetuned model and visualization method.

    Args:
        data_dir: String, root directory for dataset.
        checkpoint_path: String, path to model checkpoint.
        out_path: String, path to save per-image results to.
        save_dir: String, path to directory to save per-image visualizations.
        load_from_save_dir: Boolean, if True, load saved visualizations from
            save_dir instead of recomputing them.
        arch: String, name of torchvision.models architecture.
        converted_caffe: Boolean, if True, use weights converted from Caffe.
        dataset: String, name of dataset.
        ann_dir: String, path to root directory containing annotation files
            (used for COCO).
        split: String, name of split to use for evaluation.
        metric: String, evaluation metric ('pointing' or 'average_precision').
        input_size: Integer, length of side of the input image.
        vis_method: String, visualization method.
        final_gap_layer: Boolean, if True, add a final gap layer.
        debug: Boolean, if True, show debug visualizations.
        print_iter: Integer, frequency with which to log messages.
        save_iter: Integer, frequency with which to save per-image records.
        start_index: Integer, index at which to start evaluation (-1 to
            start from the beginning).
        end_index: Integer, index at which to stop evaluation (-1 to run
            to the end).
        layer_name: String, name of the intermediate layer to visualize
            (used for CAM and Grad-CAM).
        eps: Float, epsilon value to add to denominator for division.
        num_masks: Integer, number of random masks to generate for RISE.
        s: Integer, size of the low-resolution grid for RISE masks.
        p1: Float, probability of keeping a grid cell in a RISE mask.
        rise_filter_path: String, path at which to cache RISE masks.
        gpu_batch: Integer, batch size for RISE forward passes.

    Returns:
        (avg_acc, acc): Tuple containing the following:
            avg_acc: Float, pointing game accuracy over all classes,
            acc: ndarray, array containing accuracies for each class.
    """
    if metric == 'pointing':
        tolerance = 15
    elif metric == 'average_precision':
        tolerance = 0
    else:
        raise ValueError(f'Unsupported metric: {metric}')

    if vis_method in ['gradient', 'guided_backprop']:
        smooth_sigma = 0.02
    else:
        smooth_sigma = 0

    if debug:
        viz = visdom.Visdom(env=f'pointing_caffe_{converted_caffe}')

    # Load fine-tuned model with weights and convert to be fully convolutional.
    model = get_finetune_model(arch=arch,
                               dataset=dataset,
                               converted_caffe=converted_caffe,
                               checkpoint_path=checkpoint_path,
                               convert_to_fully_convolutional=True,
                               final_gap_layer=final_gap_layer)

    # Handle large images on CPU.
    cpu_device = torch.device('cpu')

    # Handle all other images on GPU, if available.
    device = get_device()

    model = model.to(device)

    # 'guided_backprop' as in Springenberg et al., ICLR Workshop 2015.
    if vis_method == 'guided_backprop':
        # Change backwards function for ReLU.
        def guided_hook_function(module, grad_in, grad_out):
            return (torch.clamp(grad_in[0], min=0.0), )

        register_hook_on_module(curr_module=model,
                                module_type=nn.ReLU,
                                hook_func=guided_hook_function,
                                hook_direction='backward')
    # 'cam' as in Zhou et al., CVPR 2016.
    elif vis_method == 'cam':
        if 'resnet' in arch:
            # Get third to last layer.
            layer_name = '%d' % (len(list(model.children())) - 3)
            layer_names = [layer_name]
        elif 'googlenet' in arch:
            # Get second to last layer (exclude GAP and last fc layer).
            layer_name = '%d' % (len(list(model.children())) - 2)
            layer_names = [layer_name]
        else:
            raise NotImplementedError(f'CAM not supported for arch: {arch}')
        last_layer = list(model.children())[-1]
        assert (isinstance(last_layer, nn.Conv2d))
        weights = last_layer.state_dict()['weight']
        assert (len(weights.shape) == 4)
        assert (weights.shape[2] == 1 and weights.shape[3] == 1)
    elif vis_method == 'grad_cam':
        if 'vgg16' in arch:
            if layer_name is not None:
                layer_names = [layer_name]
            else:
                # Last conv layer in features (pre-pooling, 14 x 14).
                layer_names = ['29']
        elif 'resnet50' in arch:
            if layer_name is not None:
                layer_names = [layer_name]
            else:
                # Last conv layer before the GAP and FC layers.
                layer_names = ['7']
        elif 'googlenet' in arch:
            if layer_name is not None:
                layer_names = [layer_name]
            else:
                # Last conv layer before the GAP and FC layers.
                layer_names = ['15']
        else:
            raise NotImplementedError(
                f'Grad-CAM not supported for arch: {arch}')
        # Prepare to get backpropagated gradient at intermediate layer.
        grads = []

        def hook_grads(module, grad_in, grad_out):
            grads.append(grad_in)

        assert (len(layer_names) == 1)
        hook = get_pytorch_module(
            model, layer_names[0]).register_backward_hook(hook_grads)
    elif vis_method == 'rise':
        explainer = RISE(model, input_size, device=device, gpu_batch=gpu_batch)
        if rise_filter_path is None or not os.path.exists(rise_filter_path):
            if rise_filter_path is None:
                rise_filter_path = 'masks.npy'
            create_dir_if_necessary(rise_filter_path)
            explainer.generate_masks(N=num_masks,
                                     s=s,
                                     p1=p1,
                                     savepath=rise_filter_path)
        else:
            explainer.load_masks(filepath=rise_filter_path)

    # Prepare data augmentation.
    assert (isinstance(input_size, int))
    if vis_method == 'rise':
        resize_transform = transforms.Resize((input_size, input_size))
    else:
        resize_transform = transforms.Resize(input_size)
    if converted_caffe:
        if vis_method == 'rise':
            transform = get_caffe_transform(size=(input_size, input_size))
        else:
            transform = get_caffe_transform(size=input_size)
    else:
        normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225])
        transform = transforms.Compose([
            resize_transform,
            transforms.ToTensor(),
            normalize_transform,
        ])

    if 'voc' in dataset:
        target_transform = transforms.Compose([
            FromVOCToDenseBoundingBoxes(tolerance=tolerance),
            SimpleResize(input_size),
            SimpleToTensor(),
        ])

        num_classes = 20
        year = dataset.split('_')[-1]
        dset = datasets.VOCDetection(data_dir,
                                     year=year,
                                     image_set=split,
                                     transform=transform,
                                     target_transform=target_transform)
    elif 'coco' in dataset:
        num_classes = 80
        print(ann_dir)
        ann_path = os.path.join(ann_dir, 'instances_%s.json' % split)

        dset = datasets.CocoDetection(os.path.join(data_dir, split),
                                      ann_path,
                                      transform=transform,
                                      target_transform=None)

        target_transform = transforms.Compose([
            FromCocoToDenseSegmentationMasks(dset.coco, tolerance=tolerance),
            SimpleResize(input_size),
            SimpleToTensor(),
        ])

        dset.target_transform = target_transform
    elif 'imnet' in dataset:
        num_classes = 1000
        dset = SortedFolder(os.path.join(data_dir, split), transform=transform)
    else:
        raise NotImplementedError(f'dataset not supported: {dataset}')

    print(f'dataset: {dataset}\n'
          f'split: {split}\n'
          f'arch: {arch}\n'
          f'metric: {metric}\n'
          f'smooth_sigma: {smooth_sigma}\n'
          f'tolerance: {tolerance}\n'
          f'checkpoint_path: {checkpoint_path}\n'
          f'out_path: {out_path}\n'
          f'save_dir: {save_dir}\n')
    print('Number of examples in dataset split: %d' % len(dset))
    if start_index != -1 or end_index != -1:
        if end_index == -1:
            end_index = len(dset)
        if start_index == -1:
            start_index = 0
        idx = range(start_index, end_index)
        dset = torch.utils.data.Subset(dset, idx)
        print(f'Evaluating from {start_index} to {end_index}')
    else:
        start_index = 0
        end_index = len(dset)

    if save_dir is not None:
        create_dir_if_necessary(save_dir, is_dir=True)

    # Prepare to evaluate pointing game.
    if out_path is not None:
        if os.path.exists(out_path):
            print('Loading previous records...')
            records = np.loadtxt(out_path)
        else:
            records = np.zeros((len(dset), num_classes))
    if metric == 'pointing':
        hits = np.zeros(num_classes)
        misses = np.zeros(num_classes)
    elif metric == 'average_precision':
        sum_precs = np.zeros(num_classes)
        num_examples = np.zeros(num_classes)

    if out_path is not None and np.sum(records) != 0:
        if metric == 'pointing':
            hits = np.sum(records == 1, 0)
            misses = np.sum(records == -1, 0)
        elif metric == 'average_precision':
            sum_precs = np.sum(records, 0)
            num_examples = np.sum(records != 0, 0)
        next_index = np.where(records != 0)[0][-1] + 1
        print(f'Next Index {next_index}')
    else:
        next_index = 0

    if next_index > 0:
        dset = torch.utils.data.Subset(dset, range(next_index, len(dset)))
        start_index = next_index

    loader = torch.utils.data.DataLoader(dset, batch_size=1, shuffle=False)

    image_idx = []
    y_shapes = []
    vis_shapes = []
    using_cpu = False
    t_loop = tqdm.tqdm(loader)
    for i, (x, y) in enumerate(t_loop):
        if (save_dir is not None and os.path.exists(
                os.path.join(save_dir, f'{i+start_index:06d}.pth'))
                and 'imnet' in dataset):
            print(f'Skipping image {i+start_index}; already saved.')
            continue

        # Verify shape.
        assert (x.shape[0] == 1)
        assert (y.shape[0] == 1)

        # Move data to device.
        x = x.to(device)

        # Get present classes in the image.
        if 'imnet' in dataset:
            class_idx = y.numpy()
        else:
            class_idx = np.where(
                np.sum(y[0].cpu().data.numpy(), (1, 2)) > 0)[0]
        curr_num_classes = len(class_idx)
        if curr_num_classes == 0:
            print(f'Skipping image {i+start_index}; no classes in it.')
            continue
        assert (curr_num_classes >= 1)

        if vis_method != 'rise':
            # Set input batch size to the number of classes.
            x = x.expand(curr_num_classes, *x.shape[1:])
            if vis_method != 'cam':
                x.requires_grad = True

            model.zero_grad()
            try:
                pred_y = model(x)
            except RuntimeError:
                using_cpu = True
                print(
                    f'Using CPU to handle image {i+start_index} with shape {x.shape}.'
                )
                x = x.cpu().clone().detach().requires_grad_(True)
                model.to(cpu_device)
                model.zero_grad()
                pred_y = model(x)

        # Play pointing game using the specified visualization method.
        # 'gradient' is Simonyan et al., ICLR Workshop 2014.
        if vis_method in ['gradient', 'guided_backprop']:

            # Prepare gradient.
            weights = torch.zeros_like(pred_y)
            labels = torch.from_numpy(class_idx).to(pred_y.device)
            labels = labels[:, None, None, None]
            labels_shape = (curr_num_classes, 1, weights.shape[2],
                            weights.shape[3])
            labels = labels.expand(*labels_shape)
            weights.scatter_(1, labels, 1)
            try:
                pred_y.backward(weights)
            except RuntimeError:
                # TODO(ruthfong): Handle with less redundancy.
                using_cpu = True
                print(
                    f'Using CPU to handle image {i+start_index} with shape {x.shape}.'
                )
                x = x.cpu().clone().detach().requires_grad_(True)
                model.to(cpu_device)
                model.zero_grad()
                pred_y = model(x)

                weights = torch.zeros_like(pred_y)
                labels = torch.from_numpy(class_idx).to(pred_y.device)
                labels = labels[:, None, None, None]
                labels_shape = (curr_num_classes, 1, weights.shape[2],
                                weights.shape[3])
                labels = labels.expand(*labels_shape)
                weights.scatter_(1, labels, 1)
                pred_y.backward(weights)

            # Compute gradient visualization.
            vis, _ = torch.max(torch.abs(x.grad), 1, keepdim=True)

            # Smooth gradient visualization as in Zhang et al., ECCV 2016.
            if smooth_sigma > 0:
                vis = blur_input_tensor(vis,
                                        sigma=smooth_sigma *
                                        max(vis.shape[2:]))
        elif vis_method == 'cam':
            try:
                acts = hook_get_acts(model, layer_names, x)[0]
            except RuntimeError:
                using_cpu = True
                print(
                    f'Using CPU to handle image {i+start_index} with shape {x.shape}.'
                )
                x = x.cpu().clone().detach().requires_grad_(True)
                model.to(cpu_device)
                acts = hook_get_acts(model, layer_names, x)[0]

            vis_lowres = torch.mean(acts * weights[class_idx].to(acts.device),
                                    1,
                                    keepdim=True)
            vis = nn.functional.interpolate(vis_lowres,
                                            size=y.shape[2:],
                                            mode='bilinear')
        elif vis_method == 'grad_cam':
            # Prepare gradient.
            weights = torch.zeros_like(pred_y)
            labels = torch.from_numpy(class_idx).to(pred_y.device)
            labels = labels[:, None, None, None]
            labels_shape = (curr_num_classes, 1, weights.shape[2],
                            weights.shape[3])
            labels = labels.expand(*labels_shape)
            weights.scatter_(1, labels, 1)

            # Get backpropagated gradient at intermediate layer.
            try:
                pred_y.backward(weights)
            except RuntimeError:
                # TODO(ruthfong): Handle with less redundancy.
                using_cpu = True
                print(
                    f'Using CPU to handle image {i+start_index} with shape {x.shape}.'
                )
                x = x.cpu().clone().detach().requires_grad_(True)
                model.to(cpu_device)
                model.zero_grad()
                pred_y = model(x)

                weights = torch.zeros_like(pred_y)
                labels = torch.from_numpy(class_idx).to(pred_y.device)
                labels = labels[:, None, None, None]
                labels_shape = (curr_num_classes, 1, weights.shape[2],
                                weights.shape[3])
                labels = labels.expand(*labels_shape)
                weights.scatter_(1, labels, 1)
                pred_y.backward(weights)

            assert (len(grads) == 1)
            if len(grads[0]) == 1:
                grad = grads[0][0]
            else:
                assert ('googlenet' in arch)
                grad = torch.cat(grads[0], 1)
            grad = grad.to(pred_y.device)
            del grads[:]

            # Get activations at intermediate layer.
            acts = hook_get_acts(model, layer_names, x)[0]

            # Apply global average pooling to intermediate gradient.
            grad_weights = torch.mean(grad, (2, 3), keepdim=True)

            # Linearly combine activations and gradient weights.
            grad_cam = torch.sum(acts * grad_weights, 1, keepdim=True)

            # Apply ReLU to GradCAM vis.
            vis_lowres = torch.clamp(grad_cam, min=0)

            # Upsample visualization to image size.
            vis = nn.functional.interpolate(vis_lowres,
                                            size=x.shape[2:],
                                            mode='bilinear')

        elif vis_method == 'rise':
            if load_from_save_dir:
                try:
                    vis = torch.load(
                        os.path.join(save_dir, f'{i+start_index:06d}.pth'))
                    if isinstance(vis, torch.Tensor):
                        vis = vis[class_idx]
                    else:
                        assert isinstance(vis, dict)
                        # Validate the saved class indices before unwrapping.
                        assert np.all(class_idx == vis['class_idx'])
                        vis = vis['vis']
                except FileNotFoundError:
                    print(f'No file for {i+start_index:06d}, running RISE.')
                    vis = explainer(x)
                    vis = vis.unsqueeze(1)
                    # Upsample visualization to image size.
                    vis = nn.functional.interpolate(vis,
                                                    size=y.shape[2:],
                                                    mode='bilinear')
                    torch.save(
                        vis, os.path.join(save_dir,
                                          f'{i+start_index:06d}.pth'))
                    # Keep only the classes present in this image, matching
                    # the tensor branch above.
                    vis = vis[class_idx]

            else:
                vis = explainer(x)
                vis = vis[class_idx]
                vis = vis.unsqueeze(1)
                # Upsample visualization to image size.
                if not 'imnet' in dataset:
                    vis = nn.functional.interpolate(vis,
                                                    size=y.shape[2:],
                                                    mode='bilinear')
        else:
            raise NotImplementedError(f'Unknown vis_method: {vis_method}')

        if save_dir is not None and not load_from_save_dir:
            save_path = os.path.join(save_dir, get_synset(class_idx))
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            torch.save({
                'mask': vis,
                'class_idx': class_idx,
            },
                       os.path.join(
                           save_path, 'ILSVRC2012_val_' +
                           f'{i+1+start_index:08d}.JPEG.pth'))

        if 'imnet' in dataset:
            continue

        # Move model back to GPU if necessary.
        if using_cpu:
            model.to(device)
            using_cpu = False

        y_shape = y.shape[2:]
        vis_shape = vis.shape[2:]
        image_idx.append(i + start_index)
        y_shapes.append(y_shape)
        vis_shapes.append(vis_shape)
        if y.shape[2] != vis.shape[2] or y.shape[3] != vis.shape[3]:
            print(
                f'{i+start_index:06d}: output shape {y_shape} and vis shape {vis_shape} do not match'
            )
            continue

        for class_i, c in enumerate(class_idx):
            # Check if maximum point for class-specific visualization is
            # within one of the bounding boxes for that class.
            if metric == 'pointing':
                max_i = torch.argmax(vis[class_i])
                if y[0, c, :, :].view(-1)[max_i] > 0.5:
                    hits[c] += 1
                    if out_path is not None:
                        records[i, c] = 1
                else:
                    misses[c] += 1
                    if out_path is not None:
                        records[i, c] = -1
            elif metric == 'average_precision':
                # Flatten visualization and ground truth data.
                y_flat = y[0, c].reshape(-1).float().cpu().numpy()
                vis_flat = vis[class_i].reshape(-1).cpu().data.numpy()
                ap = average_precision_score(y_flat, vis_flat)
                sum_precs[c] += ap
                num_examples[c] += 1
                if out_path is not None:
                    records[i, c] = ap
            else:
                raise ValueError(f'Unsupported metric: {metric}')
            if debug:

                def normalize_arr(x):
                    x_min, x_max = np.min(x), np.max(x)
                    return (x - x_min) / (x_max - x_min)

                import matplotlib
                matplotlib.use('Agg')
                import matplotlib.pyplot as plt
                if converted_caffe:
                    viz.image(vutils.make_grid(CaffeChannelSwap()(
                        x[0]).unsqueeze(0),
                                               normalize=True),
                              win=0)
                else:
                    viz.image(vutils.make_grid(x, normalize=True), win=0)
                viz.image(vutils.make_grid(vis[class_i], normalize=True),
                          win=1)
                # time.sleep(1)
                f, ax = plt.subplots(1, 1)
                if converted_caffe:
                    ax.imshow(
                        vutils.make_grid(CaffeChannelSwap()(x[0]).unsqueeze(0),
                                         normalize=True).cpu().data.squeeze().
                        numpy().transpose(1, 2, 0))
                else:
                    ax.imshow(
                        vutils.make_grid(x, normalize=True).cpu().data.squeeze(
                        ).numpy().transpose(1, 2, 0))
                ax.imshow(resize(
                    normalize_arr(vis[class_i].cpu().data.numpy().transpose(
                        1, 2, 0)), x.shape[2:]).squeeze(),
                          alpha=0.5,
                          cmap='jet')
                create_dir_if_necessary(os.path.join(save_dir, 'debug_images'),
                                        True)
                plt.savefig(
                    os.path.join(save_dir, 'debug_images',
                                 f'{i+start_index:06d}_class_{class_i}.png'))
                plt.close()
                # print(np.argmax(y[0].cpu().data.numpy()))

        if i % print_iter == 0:
            if metric == 'pointing':
                running_avg = np.mean(hits / (hits + misses + eps))
                metric_name = 'Avg Acc'
            elif metric == 'average_precision':
                running_avg = np.mean(sum_precs / (num_examples + eps))
                metric_name = 'Mean Avg Prec'
            t_loop.set_description(f'{metric_name} {running_avg:.4f}')
            if debug:
                viz.image(vutils.make_grid(x[0].unsqueeze(0), normalize=True),
                          0)
                viz.image(vutils.make_grid(vis, normalize=True), 1)
        if i % save_iter == 0 and out_path is not None:
            create_dir_if_necessary(out_path)
            np.savetxt(out_path, records)
            torch.save(
                {
                    'image_idx': image_idx,
                    'vis_shapes': vis_shapes,
                    'y_shapes': y_shapes
                }, f'errors_new_v2_{vis_method}_{dataset}_{metric}.pth')

    torch.save(
        {
            'image_idx': image_idx,
            'vis_shapes': vis_shapes,
            'y_shapes': y_shapes
        }, f'errors_new_v2_{vis_method}_{dataset}_{metric}.pth')
    if out_path is not None:
        create_dir_if_necessary(out_path)
        np.savetxt(out_path, records)
        compute_metrics(out_path, metric=metric, dataset=dataset)
    if metric == 'pointing':
        acc = hits / (hits + misses)
        avg_acc = np.mean(acc)
        print('Avg Acc: %.4f' % avg_acc)
        print(acc)
        return avg_acc, acc
    elif metric == 'average_precision':
        class_mean_avg_prec = sum_precs / num_examples
        mean_avg_prec = np.mean(class_mean_avg_prec)
        print('Mean Avg Prec: %.4f' % mean_avg_prec)
        print(class_mean_avg_prec)
        return mean_avg_prec, class_mean_avg_prec
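A minimal invocation sketch for the evaluation above; the paths below are placeholders, and the checkpoint is assumed to be a VGG16 fine-tuned on VOC 2007:

avg_acc, per_class_acc = pointing_game(
    data_dir='data/voc',                              # hypothetical path
    checkpoint_path='checkpoints/vgg16_voc2007.pth',  # hypothetical path
    arch='vgg16',
    dataset='voc_2007',
    split='test',
    metric='pointing',
    vis_method='gradient')
print('pointing accuracy: %.4f' % avg_acc)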