def __init__(self, size: torch.Size, root: str = None, limit_var=np.inf, logger: Logger = None):
    """
    Outlier Exposure dataset for ImageNet22k.

    :param size: size of the samples in n x c x h x w, samples will be resized to h x w (h must equal w and
        c must be 1 or 3). Exactly n images are extracted per iteration of the data_loader, so the dataset
        has to offer at least n samples; enlarging the dataset by repetitions is NOT implemented and a
        NotImplementedError is raised instead.
        For online supervision n should be set to 1 because only one sample is extracted at a time.
    :param root: root directory where data is found or is to be downloaded to. Despite the None default,
        a string is required because the ImageNet22k path is derived from it.
    :param limit_var: limits the number of different samples, i.e. randomly chooses limit_var many samples
        from all available ones to be the training data. None or infinity means "no limit".
    :param logger: logger, may be None (then loading is simply not timed).
    :raises NotImplementedError: if fewer than n (= size[0]) samples are available.
    """
    # Local import keeps this block self-contained; only needed when no logger is passed.
    from contextlib import nullcontext

    assert len(size) == 4 and size[2] == size[3]
    assert size[1] in [1, 3]

    # ImageNet22k is expected next to an 'imagenet' directory, or below any other root.
    if root.endswith('imagenet'):
        root = pt.join(root, '..', 'imagenet22k')
    else:
        root = pt.join(root, 'imagenet22k')
    root = pt.join(root, 'fall11_whole_extracted')  # important to have a second layer, to speed up load meta file
    self.root = root
    self.logger = logger

    # logger defaults to None, so only time the loading when a logger is actually available.
    timer = logger.timeit('Loading ImageNet22k') if logger is not None else nullcontext()
    with timer:
        super().__init__(root=root, size=size, logger=logger)

    self.transform = transforms.Compose([
        transforms.Resize(size[2]),
        transforms.ToTensor()
    ])

    self.picks = None
    if limit_var is not None:
        available = len(self)
        if limit_var < available:
            # Randomly choose limit_var distinct sample indices.
            self.picks = np.random.choice(len(self.samples), size=limit_var, replace=False)
        elif limit_var > available and np.isfinite(limit_var):
            # Only report an explicitly requested (finite) limit that cannot be met;
            # the infinite default means "no limit" and needs no notice.
            self.logprint(
                'OEImageNet22 shall be limited to {} samples, but ImageNet22k contains only {} samples, thus using all.'
                .format(limit_var, available), fps=False
            )

    # Re-evaluate the length here, as it may depend on the picks chosen above.
    if len(self) < size[0]:
        raise NotImplementedError(
            'ImageNet22k provides only {} samples, but {} per iteration were requested; '
            'enlarging the dataset by repetitions is not implemented.'.format(len(self), size[0])
        )
def trainer_setup(
        dataset: str, datadir: str, logdir: str, net: str, bias: bool, learning_rate: float, weight_decay: float,
        lr_sched_param: List[float], batch_size: int, optimizer_type: str, scheduler_type: str,
        objective: str, preproc: str, supervise_mode: str, nominal_label: int,
        online_supervision: bool, oe_limit: int, noise_mode: str,
        workers: int, quantile: float, resdown: int, gauss_std: float, blur_heatmaps: bool,
        cuda: bool, config: str, log_start_time: int = None, normal_class: int = 0,
) -> dict:
    """
    Builds everything a Trainer needs from the parameters handed over by a runner
    (see fcdd.runners.bases.py): logger, dataset and its data loaders, network, optimizer and
    learning rate scheduler. Additionally stores the network parameters/config via the logger and
    logs a preview image of the dataset.

    :param dataset: dataset identifier string (see :data:`fcdd.datasets.DS_CHOICES`).
    :param datadir: directory where the datasets are found or to be downloaded to.
    :param logdir: directory where log data is to be stored.
    :param net: network model identifier string (see :func:`fcdd.models.choices`).
    :param bias: whether to use bias in the network layers.
    :param learning_rate: initial learning rate.
    :param weight_decay: weight decay (L2 penalty) regularizer.
    :param lr_sched_param: learning rate scheduler parameters, the format depends on the scheduler type.
        'milestones': at least two elements, the first is the factor the learning rate is decreased by
        at each milestone, the remaining ones are the milestones (epochs).
        'lambda': exactly one element, the factor the learning rate is decreased by at each epoch.
    :param batch_size: batch size, i.e. number of data samples returned per iteration of the data loader.
    :param optimizer_type: optimizer type, one of {'sgd', 'adam'}.
    :param scheduler_type: learning rate scheduler type, one of {'lambda', 'milestones'}.
    :param objective: the training objective. See :data:`OBJECTIVES`.
    :param preproc: data preprocessing pipeline identifier string (see :data:`fcdd.datasets.PREPROC_CHOICES`).
    :param supervise_mode: the type of generated artificial anomalies.
        See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
    :param nominal_label: the label that is to be returned to mark nominal samples.
    :param online_supervision: whether to sample anomalies online in each epoch,
        or offline before training (same for all epochs in this case).
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure
        (defined in noise_mode).
    :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
    :param workers: how many subprocesses to use for data loading.
    :param quantile: the quantile that is used to normalize the generated heatmap images.
    :param resdown: the maximum resolution of logged images, images will be downsampled if necessary.
    :param gauss_std: a constant value for the standard deviation of the Gaussian kernel used for
        upsampling and blurring, the default value is determined by
        :func:`fcdd.datasets.noise.kernel_size_to_std`.
    :param blur_heatmaps: whether to blur heatmaps.
    :param cuda: whether to use GPU.
    :param config: some config text that is to be stored in the config.txt file.
    :param log_start_time: the start time of the experiment.
    :param normal_class: the class that is to be considered nominal.
    :return: a dictionary containing all necessary parameters to be passed to a Trainer instance.
    """
    # Validate the identifier strings, in the same order as they are documented.
    checks = (
        (objective, OBJECTIVES, 'unknown objective: {}'),
        (supervise_mode, SUPERVISE_MODES, 'unknown supervise mode: {}'),
        (noise_mode, MODES, 'unknown noise mode: {}'),
    )
    for value, choices, template in checks:
        assert value in choices, template.format(value)

    device = torch.device('cuda:0' if cuda else 'cpu')

    # Logger first, so that dataset loading can already log.
    log_root = pt.abspath(pt.join(logdir, ''))
    logger = Logger(log_root, exp_start_time=log_start_time)

    data_root = pt.abspath(pt.join(datadir, ''))
    ds = load_dataset(
        dataset, data_root, normal_class, preproc, supervise_mode, noise_mode,
        online_supervision, nominal_label, oe_limit, logger=logger
    )
    loaders = ds.loaders(batch_size=batch_size, num_workers=workers)

    # Network: report its 'reception' attribute (if it has one) before moving it to the device.
    model = load_nets(net, ds.shape, bias=bias)
    if hasattr(model, 'reception'):
        reception = model.reception
    else:
        reception = None
    logger.logtxt('##### NET RECEPTION {} #####'.format(reception), print=True)
    model = model.to(device)

    optimizer, scheduler = pick_opt_sched(
        model, learning_rate, weight_decay, lr_sched_param, optimizer_type, scheduler_type
    )
    logger.save_params(model, config)

    # Row headers of the preview follow the order of the labels.
    anomalies_first = hasattr(ds, 'nominal_label') and not (ds.nominal_label < ds.anomalous_label)
    ds_order = ['anom', 'norm'] if anomalies_first else ['norm', 'anom']

    preview = ds.preview(20)
    if isinstance(ds.train_set, GTMapADDataset):
        # Datasets with ground-truth maps get additional header entries for those rows.
        gt_headers = ['gtno' if s == 'norm' else 'gtan' for s in ds_order]
        rowheaders = [*ds_order, '', *gt_headers]
    else:
        rowheaders = ds_order
    logger.imsave('ds_preview', torch.cat(list(preview)), nrow=preview.size(1), rowheaders=rowheaders)

    setup = {
        'net': model,
        'dataset_loaders': loaders,
        'opt': optimizer,
        'sched': scheduler,
        'logger': logger,
        'device': device,
        'objective': objective,
        'quantile': quantile,
        'resdown': resdown,
        'gauss_std': gauss_std,
        'blur_heatmaps': blur_heatmaps,
    }
    return setup
def combine_specific_viz_ids_pics(srcs: List[str], out: str = None, setup: List[str] = ('base', 'hsc', 'ae'),
                                  skip_further=False, only_cls: List[int] = None):
    """
    Combines heatmap images (visualization ids) for several old experiments for the same input images.
    Depending on the setup, it creates an image with input images at the top and heatmap images below,
    where each row corresponds to one experiment and each column to one input.
    A row can also contain ground-truth heatmaps.
    The combined heatmap images are stored on the disk according to the out parameter.

    Each src is expected to contain either class directories named ``normal_<cls>`` or, for a one class
    experiment, directly the iteration directories named ``it_*``. In the latter case the class is taken
    from the ``normal_class`` entry of the iteration's config.txt.

    :param srcs: paths to root directories of old experiments
    :param out: directory in which to put the combined images (class and seed-wise),
        defaults to ``srcs[0] + '_COMBINED_PAPER_PICS'``.
    :param setup: types of experiments/rows, need to be in the order of srcs, each element has to be in OPTIONS.
        "base": FCDD experiment, always needs to be the first element of setup!
        "hsc": HSC experiment with gradient heatmaps.
        "ae": Autoencoder experiment with reconstruction loss heatmaps.
        "gts": Ground-truth heatmaps, if present needs to be the last element of setup.
    :param skip_further: if an experiment has more than one type of heatmap images, i.e. its logged images
        contain more than 2 rows (first row is always input), consider only the first type of heatmap.
    :param only_cls: list of classes, classes not part of the list are skipped, None (or an empty list)
        means no classes are skipped
    :raises ValueError: if len(srcs) != len(setup), if the presence of 'specific_viz_ids' directories does
        not fit the position of a src, if a class/iteration is missing in the base src, if an 'ae_' file
        is found for a non-ae src, or if the inputs of the experiments differ.
    :return:
    """
    # TODO get rid of setup?
    assert all([s in OPTIONS for s in setup])
    assert setup[0] == 'base'
    if 'gts' in setup:
        assert setup[-1] == 'gts'

    if out is None:
        out = srcs[0] + '_COMBINED_PAPER_PICS'

    if len(srcs) != len(setup):
        raise ValueError('fixed len of src required, {}, but found {}!'.format(
            ' '.join(['({}) {}'.format(i + 1, s) for i, s in enumerate(setup)]), len(srcs)))

    filter_classes = only_cls is not None and len(only_cls) > 0

    # pics[class][iteration][viz id] -> list of loaded tensors, one per src (in the order of srcs)
    pics = {}
    for n, src in enumerate(srcs):
        cls_labels = [pt.join(src, c) for c in os.listdir(src)]
        cls_labels.sort(key=pt.getmtime)
        cls_labels = [pt.basename(c) for c in cls_labels]
        if all([c.startswith('it_') for c in cls_labels if pt.isdir(pt.join(src, c))]):  # one class experiment
            cls_labels = ['.']
        for cls_dir in cls_labels:
            if not pt.isdir(pt.join(src, cls_dir)):
                continue
            one_class = cls_dir == '.'
            if not one_class:
                # The 'normal_<cls>' naming only applies to real class directories. Asserting it for the
                # one class placeholder '.' would make the one class handling below unreachable.
                assert cls_dir.startswith('normal_')
                if filter_classes and int(cls_dir[7:]) not in only_cls:
                    continue
            print('collecting pictures of {} {}...'.format(src, cls_dir))
            for it_dir in os.listdir(pt.join(src, cls_dir)):
                if pt.isfile(pt.join(src, cls_dir, it_dir)):
                    continue
                cfg = read_cfg(pt.join(src, cls_dir, it_dir, 'config.txt'))
                # For one class experiments the class is only known from the config.
                # NOTE(review): assumes cfg['normal_class'] is convertible to int -- confirm against read_cfg.
                if one_class and filter_classes and int(cfg['normal_class']) not in only_cls:
                    continue
                tims_dir = pt.join(src, cls_dir, it_dir, 'tims')
                if n == 0:
                    if pt.exists(pt.join(tims_dir, 'specific_viz_ids')):
                        raise ValueError(
                            'First src should not contains specific viz ids, as first src should be the base!'
                        )
                    for root, _, files in os.walk(tims_dir):
                        for f in files:
                            assert f[-4:] == '.pth'
                            # NOTE: torch.load unpickles data, only use on trusted experiment logs.
                            pics.setdefault(cls_dir, {}).setdefault(it_dir, {})[f[:-4]] = [
                                torch.load(pt.join(root, f))
                            ]
                else:
                    if not pt.exists(pt.join(tims_dir, 'specific_viz_ids')):
                        raise ValueError('Src {} should contain specific viz ids, but it doesnt!'.format(src))
                    for root, _, files in os.walk(pt.join(tims_dir, 'specific_viz_ids')):
                        for f in files:
                            assert f[-4:] == '.pth'
                            if one_class and cls_dir not in pics:
                                warnings.warn('Seems that src {} is a one class experiment...'.format(src))
                                cls = 'normal_{}'.format(cfg['normal_class'])
                            else:
                                cls = cls_dir
                            if cls not in pics or it_dir not in pics[cls]:
                                # Report the key that was actually looked up (differs from cls_dir for
                                # one class experiments).
                                raise ValueError('{} {} is missing in base src!!'.format(cls, it_dir))
                            if setup[n] in ('ae', ):
                                if not f.startswith('ae_'):
                                    continue
                                # strip the 'ae_' prefix to obtain the viz id used by the base src
                                # NOTE: torch.load unpickles data, only use on trusted experiment logs.
                                pics[cls][it_dir][f[3:-4]].append(torch.load(pt.join(root, f)))
                            else:
                                if f.startswith('ae_'):
                                    raise ValueError(
                                        'ae has been found in position {}, but shouldnt be!'.format(n)
                                    )
                                # NOTE: torch.load unpickles data, only use on trusted experiment logs.
                                pics[cls][it_dir][f[:-4]].append(torch.load(pt.join(root, f)))

    logger = Logger(out)
    for cls_dir in pics:
        print('creating pictures for {} {}...'.format(out, cls_dir))
        for it_dir in pics[cls_dir]:
            for file in pics[cls_dir][it_dir]:
                combined_pic = []
                inps = []
                gts = None
                tensors = pics[cls_dir][it_dir][file]
                if len(tensors) != len(srcs):
                    print(
                        'Some specific viz id tims are missing for {} {}!! Skipping them...'
                        .format(cls_dir, it_dir), file=sys.stderr
                    )
                    continue

                # 0 == base src
                t = tensors[0]
                _, _, _, _, _ = t.shape  # tensors need to be of shape rows x cols x c x h x w
                inps.append(t[0])
                if 'gts' in setup:
                    # final row of the base src holds the ground-truth maps
                    combined_pic.extend(t[:2] if skip_further else t[:-1])
                    gts = t[-1]
                else:
                    combined_pic.extend(t[:2] if skip_further else t)

                for t in tensors[1:]:
                    rows, _, _, _, _ = t.shape
                    if rows == 3:  # assume gts in final row
                        t = t[:-1]
                    inps.append(t[0])
                    combined_pic.append(t[1])

                # ADD GTMAP
                if gts is not None:
                    combined_pic.append(gts)

                # check if all inputs have been the same (pairwise over all srcs)
                for i in range(len(srcs)):
                    for j in range(i + 1, len(srcs)):
                        if (inps[i] != inps[j]).sum() > 0:
                            raise ValueError(
                                'SRC {} and SRC {} have different inputs!!!'.format(srcs[i], srcs[j])
                            )

                # combine
                new_cols = combined_pic[0].size(0)
                tim = torch.cat(combined_pic)
                logger.imsave(file, tim, nrow=new_cols, scale_mode='none', suffix=pt.join(cls_dir, it_dir))
    print('Successfully combined pics in {}.'.format(out))