Пример #1
0
def trainer_setup(
    dataset: str,
    datadir: str,
    logdir: str,
    net: str,
    bias: bool,
    learning_rate: float,
    weight_decay: float,
    lr_sched_param: List[float],
    batch_size: int,
    optimizer_type: str,
    scheduler_type: str,
    objective: str,
    preproc: str,
    supervise_mode: str,
    nominal_label: int,
    online_supervision: bool,
    oe_limit: int,
    noise_mode: str,
    workers: int,
    quantile: float,
    resdown: int,
    gauss_std: float,
    blur_heatmaps: bool,
    cuda: bool,
    config: str,
    log_start_time: int = None,
    normal_class: int = 0,
) -> dict:
    """
    Builds everything a Trainer needs from the parameters handed over by a runner
    (see fcdd.runners.bases.py): logger, dataset and its data loaders, network, optimizer,
    and learning rate scheduler. As a side effect, the network parameters together with the
    config text are saved via the logger, and a preview image of the dataset is logged.
    :param dataset: dataset identifier string (see :data:`fcdd.datasets.DS_CHOICES`).
    :param datadir: directory where the datasets are found or to be downloaded to.
    :param logdir: directory where log data is to be stored.
    :param net: network model identifier string (see :func:`fcdd.models.choices`).
    :param bias: whether to use bias in the network layers.
    :param learning_rate: initial learning rate.
    :param weight_decay: weight decay (L2 penalty) regularizer.
    :param lr_sched_param: learning rate scheduler parameters. Format depends on the scheduler type.
        For 'milestones' needs to have at least two elements, the first corresponding to the factor
        the learning rate is decreased by at each milestone, the rest corresponding to milestones (epochs).
        For 'lambda' needs to have exactly one element, i.e. the factor the learning rate is decreased by
        at each epoch.
    :param batch_size: batch size, i.e. number of data samples that are returned per iteration of the data loader.
    :param optimizer_type: optimizer type, needs to be one of {'sgd', 'adam'}.
    :param scheduler_type: learning rate scheduler type, needs to be one of {'lambda', 'milestones'}.
    :param objective: the training objective, has to be an element of :data:`OBJECTIVES`.
    :param preproc: data preprocessing pipeline identifier string (see :data:`fcdd.datasets.PREPROC_CHOICES`).
    :param supervise_mode: the type of generated artificial anomalies, has to be an element of
        :data:`SUPERVISE_MODES`.
        See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
    :param nominal_label: the label that is to be returned to mark nominal samples.
    :param online_supervision: whether to sample anomalies online in each epoch,
        or offline before training (same for all epochs in this case).
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode).
    :param noise_mode: the type of noise used, has to be an element of :data:`MODES`
        (see :mod:`fcdd.datasets.noise_mode`).
    :param workers: how many subprocesses to use for data loading.
    :param quantile: the quantile that is used to normalize the generated heatmap images.
    :param resdown: the maximum resolution of logged images, images will be downsampled if necessary.
    :param gauss_std: a constant value for the standard deviation of the Gaussian kernel used for upsampling and
        blurring, the default value is determined by :func:`fcdd.datasets.noise.kernel_size_to_std`.
    :param blur_heatmaps: whether to blur heatmaps.
    :param cuda: whether to use GPU ('cuda:0') instead of the CPU.
    :param config: some config text that is to be stored in the config.txt file.
    :param log_start_time: the start time of the experiment, forwarded to the logger.
    :param normal_class: the class that is to be considered nominal.
    :return: a dictionary containing all necessary parameters to be passed to a Trainer instance, with the keys
        'net', 'dataset_loaders', 'opt', 'sched', 'logger', 'device', 'objective', 'quantile', 'resdown',
        'gauss_std', and 'blur_heatmaps'.
    """
    # Reject unknown identifiers before anything touches the disk.
    assert objective in OBJECTIVES, 'unknown objective: {}'.format(objective)
    assert supervise_mode in SUPERVISE_MODES, \
        'unknown supervise mode: {}'.format(supervise_mode)
    assert noise_mode in MODES, 'unknown noise mode: {}'.format(noise_mode)

    device = torch.device('cuda:0' if cuda else 'cpu')

    # Logger first, since the dataset loader reports through it.
    log_root = pt.abspath(pt.join(logdir, ''))
    logger = Logger(log_root, exp_start_time=log_start_time)

    data_root = pt.abspath(pt.join(datadir, ''))
    ds = load_dataset(
        dataset, data_root, normal_class, preproc, supervise_mode,
        noise_mode, online_supervision, nominal_label, oe_limit,
        logger=logger,
    )
    loaders = ds.loaders(batch_size=batch_size, num_workers=workers)

    # Network: build, report its 'reception' attribute (if it has one), move to the device.
    model = load_nets(net, ds.shape, bias=bias)
    if hasattr(model, 'reception'):
        reception = model.reception
    else:
        reception = None
    logger.logtxt('##### NET RECEPTION {} #####'.format(reception), print=True)
    model = model.to(device)

    optimizer, scheduler = pick_opt_sched(
        model, learning_rate, weight_decay, lr_sched_param, optimizer_type,
        scheduler_type,
    )
    logger.save_params(model, config)

    # Row headers of the preview follow the label order: the smaller label comes first.
    nominal_first = (not hasattr(ds, 'nominal_label')
                     or ds.nominal_label < ds.anomalous_label)
    ds_order = ['norm', 'anom'] if nominal_first else ['anom', 'norm']

    preview = ds.preview(20)
    stacked = torch.cat(list(preview))
    per_row = preview.size(1)
    if isinstance(ds.train_set, GTMapADDataset):
        # Ground-truth datasets get an empty header row plus one header per ground-truth row.
        gt_headers = ['gtno' if s == 'norm' else 'gtan' for s in ds_order]
        headers = ds_order + [''] + gt_headers
    else:
        headers = ds_order
    logger.imsave('ds_preview', stacked, nrow=per_row, rowheaders=headers)

    return dict([
        ('net', model),
        ('dataset_loaders', loaders),
        ('opt', optimizer),
        ('sched', scheduler),
        ('logger', logger),
        ('device', device),
        ('objective', objective),
        ('quantile', quantile),
        ('resdown', resdown),
        ('gauss_std', gauss_std),
        ('blur_heatmaps', blur_heatmaps),
    ])
Пример #2
0
def combine_specific_viz_ids_pics(srcs: List[str],
                                  out: str = None,
                                  setup: List[str] = ('base', 'hsc', 'ae'),
                                  skip_further=False,
                                  only_cls: List[int] = None):
    """
    Combines heatmap images (visualization ids) for several old experiments for the same input images.
    Depending on the setup, it creates an image with input images at the top and heatmap images below, where
    each row corresponds to one experiment and each column to one input.
    A row can also contain ground-truth heatmaps.
    The combined heatmap images are stored on the disk according to the out parameter.
    :param srcs: paths to root directories of old experiments
    :param out: directory in which to put the combined images (class and seed-wise).
        Defaults to the first src with '_COMBINED_PAPER_PICS' appended.
    :param setup: types of experiments/rows, need to be in the order of srcs, each element has to be in OPTIONS.
        "base": FCDD experiment, always needs to be the first element of setup!
        "hsc": HSC experiment with gradient heatmaps.
        "ae": Autoencoder experiment with reconstruction loss heatmaps.
        "gts": Ground-truth heatmaps, needs to be the last element of setup if present.
    :param skip_further: if an experiment has more than one type of heatmap images, i.e. its logged images
        contain more than 2 rows (first row is always input), consider only the first type of heatmap.
    :param only_cls: list of classes, classes not part of the list are skipped, None means no classes are skipped.
        For a one-class src (whose sub-directories are all 'it_*') the class is taken from the
        'normal_class' entry of each iteration's config.txt.
    :raises ValueError: if srcs and setup differ in length, if the first src contains specific viz ids or any
        other src lacks them, if a class/iteration of a later src is missing in the base src, if
        an 'ae_' file is found in a non-ae src, or if the srcs were logged for different inputs.
    :return:
    """
    # TODO get rid of setup?
    assert all(s in OPTIONS for s in setup)
    assert setup[0] == 'base'
    if 'gts' in setup:
        assert setup[-1] == 'gts'

    if out is None:
        out = srcs[0] + '_COMBINED_PAPER_PICS'

    if len(srcs) != len(setup):
        raise ValueError('fixed len of src required, {}, but found {}!'.format(
            ' '.join(['({}) {}'.format(i + 1, s)
                      for i, s in enumerate(setup)]), len(srcs)))

    filter_cls = only_cls is not None and len(only_cls) > 0

    # pics[class dir][iteration dir][file stem] -> list of loaded tensors, one per src (base first)
    pics = {}
    for n, src in enumerate(srcs):
        # class directories in order of their modification time
        cls_labels = [pt.join(src, c) for c in os.listdir(src)]
        cls_labels.sort(key=pt.getmtime)
        cls_labels = [pt.basename(c) for c in cls_labels]
        if all(c.startswith('it_') for c in cls_labels
               if pt.isdir(pt.join(src, c))):  # one class experiment
            cls_labels = ['.']
        for cls_dir in cls_labels:
            if not pt.isdir(pt.join(src, cls_dir)):
                continue
            one_class = cls_dir == '.'
            if not one_class:
                # The name-based checks only make sense for 'normal_<class>' directories. For a one-class src
                # cls_dir is '.', which used to trip the assertion and made the one-class handling
                # further below unreachable.
                assert cls_dir.startswith('normal_')
                if filter_cls and int(cls_dir[7:]) not in only_cls:
                    continue
            print('collecting pictures of {} {}...'.format(src, cls_dir))
            for it_dir in os.listdir(pt.join(src, cls_dir)):
                if pt.isfile(pt.join(src, cls_dir, it_dir)):
                    continue
                cfg = read_cfg(pt.join(src, cls_dir, it_dir, 'config.txt'))
                # One-class src: the class is only known from the config, so filter per iteration.
                # NOTE(review): assumes cfg['normal_class'] is convertible with int() -- confirm against read_cfg.
                if one_class and filter_cls and int(
                        cfg['normal_class']) not in only_cls:
                    continue
                tims_dir = pt.join(src, cls_dir, it_dir, 'tims')
                if n == 0:
                    if pt.exists(pt.join(tims_dir, 'specific_viz_ids')):
                        raise ValueError(
                            'First src should not contains specific viz ids, as first src should be the base!'
                        )
                    for root, _, files in os.walk(tims_dir):
                        for f in files:
                            assert f[-4:] == '.pth'
                            if cls_dir not in pics:
                                pics[cls_dir] = {}
                            if it_dir not in pics[cls_dir]:
                                pics[cls_dir][it_dir] = {}
                            # NOTE(review): torch.load unpickles the file; only use with trusted experiment logs.
                            pics[cls_dir][it_dir][f[:-4]] = [
                                torch.load(pt.join(root, f))
                            ]
                else:
                    if not pt.exists(pt.join(tims_dir, 'specific_viz_ids')):
                        raise ValueError(
                            'Src {} should contain specific viz ids, but it doesnt!'
                            .format(src))
                    for root, _, files in os.walk(
                            pt.join(tims_dir, 'specific_viz_ids')):
                        for f in files:
                            assert f[-4:] == '.pth'
                            if cls_dir == '.' and cls_dir not in pics:
                                # base src is organized class-wise, this one is not: map it to its class
                                warnings.warn(
                                    'Seems that src {} is a one class experiment...'
                                    .format(src))
                                cls = 'normal_{}'.format(cfg['normal_class'])
                            else:
                                cls = cls_dir
                            if cls not in pics or it_dir not in pics[cls]:
                                # report the resolved class, '.' would not tell which class is missing
                                raise ValueError(
                                    '{} {} is missing in base src!!'.format(
                                        cls, it_dir))
                            if setup[n] in ('ae', ):
                                # autoencoder srcs contribute only their 'ae_' prefixed files
                                if not f.startswith('ae_'):
                                    continue
                                pics[cls][it_dir][f[3:-4]].append(
                                    torch.load(pt.join(root, f)))
                            else:
                                if f.startswith('ae_'):
                                    raise ValueError(
                                        'ae has been found in position {}, but shouldnt be!'
                                        .format(n))
                                pics[cls][it_dir][f[:-4]].append(
                                    torch.load(pt.join(root, f)))

    logger = Logger(out)

    for cls_dir in pics:
        print('creating pictures for {} {}...'.format(out, cls_dir))
        for it_dir in pics[cls_dir]:
            for file in pics[cls_dir][it_dir]:
                combined_pic = []
                inps = []
                gts = None
                tensors = pics[cls_dir][it_dir][file]
                if len(tensors) != len(srcs):
                    print(
                        'Some specific viz id tims are missing for {} {}!! Skipping them...'
                        .format(cls_dir, it_dir),
                        file=sys.stderr)
                    continue

                # 0 == base src
                t = tensors[0]
                rows, cols, c, h, w = t.shape  # also enforces a 5-dimensional tensor
                inps.append(t[0])
                if 'gts' in setup:
                    # last row of the base tensor is taken as ground truth and re-added at the very end
                    combined_pic.extend([*t[:2 if skip_further else -1]])
                    gts = t[-1]
                else:
                    combined_pic.extend([*t[:2 if skip_further else None]])

                # remaining srcs contribute their first heatmap row each
                for t in tensors[1:]:
                    rows, cols, c, h, w = t.shape
                    if rows == 3:  # assume gts in final row
                        t = t[:-1]
                    inps.append(t[0])
                    combined_pic.append(t[1])

                # ADD GTMAP
                if gts is not None:
                    combined_pic.append(gts)

                # check that all srcs have been logged for the same inputs (pairwise)
                for i in range(len(srcs)):
                    for j in range(i + 1, len(srcs)):
                        if (inps[i] != inps[j]).sum() > 0:
                            raise ValueError(
                                'SRC {} and SRC {} have different inputs!!!'.
                                format(srcs[i], srcs[j]))

                # combine
                new_cols = combined_pic[0].size(0)
                tim = torch.cat(combined_pic)
                logger.imsave(file,
                              tim,
                              nrow=new_cols,
                              scale_mode='none',
                              suffix=pt.join(cls_dir, it_dir))

    print('Successfully combined pics in {}.'.format(out))