示例#1
0
 def __init__(
     self,
     inputs: Union[None, str, Iterable[str]] = None,
     outputs: Union[None, str, Iterable[str]] = None,
     mode: Union[None, str, Iterable[str]] = None,
 ) -> None:
     """Record the input keys, output keys, and mode specification.

     Args:
         inputs: Key(s) associated with this object; stored after passing through `to_list`.
         outputs: Key(s) produced by this object; stored after passing through `to_list`.
         mode: Mode specification; stored after passing through `to_set` and then `parse_modes`.
     """
     input_keys = to_list(inputs)
     output_keys = to_list(outputs)
     self.inputs = input_keys
     self.outputs = output_keys
     self.mode = parse_modes(to_set(mode))
示例#2
0
 def __init__(self,
              finals: Union[str, List[str]],
              outputs: Union[str, List[str]],
              inputs: Union[None, str, List[str]] = None,
              model: Union[None, tf.keras.Model, torch.nn.Module] = None,
              mode: Union[None, str, Iterable[str]] = None):
     """Configure a gradient op from either explicit `inputs` or a `model`.

     Exactly one of `inputs` and `model` must be supplied. The parent constructor receives the `inputs` keys
     followed by the `finals` keys as its inputs.

     Args:
         finals: Key(s) of the values to differentiate.
         outputs: Key(s) under which results are written; must match `finals` in count.
         inputs: Key(s) to differentiate with respect to. Must match `finals` in count when given.
         model: A `tf.keras.Model` or `torch.nn.Module` to use instead of `inputs`.
         mode: Mode specification forwarded to the parent constructor.

     Raises:
         AssertionError: If both or neither of `inputs` / `model` are provided, if the key counts disagree, or if
             `model` is not a recognized model type.
     """
     inputs = to_list(inputs)
     finals = to_list(finals)
     outputs = to_list(outputs)
     # Compare against None explicitly: the truthiness of a module that defines __len__ reflects its length, so
     # bool(model) is not a reliable "was a model given" test.
     assert (model is not None) != bool(
         inputs), "Must provide either one of 'inputs' or 'model'"
     if model is None:
         assert len(inputs) == len(finals) == len(outputs), \
             "GradientOp requires the same number of inputs, finals, and outputs"
     else:
         assert isinstance(
             model,
             (tf.keras.Model, torch.nn.Module)), "Unrecognized model format"
         assert len(finals) == len(
             outputs
         ), "GradientOp requires the same number of finals, and outputs"
     # Build a new list instead of extending in place: `to_list` may hand back the caller's own list object, and
     # extending that would silently modify the caller's argument.
     super().__init__(inputs=inputs + finals, outputs=outputs, mode=mode)
     self.model = model
     self.retain_graph = True
示例#3
0
def build(model_def, model_name, optimizer, loss_name, custom_objects=None):
    """Build and compile one or more keras models.

    Args:
        model_def (function, str): Either a callable that returns the model(s), or the path of a saved model file
            that is loaded with `tf.keras.models.load_model`.
        model_name (str, list, tuple): Model name(s).
        optimizer (str, optimizer, list, tuple): Optimizer(s).
        loss_name (str, list, tuple): Loss name(s).
        custom_objects (dict): Passed as `custom_objects` to `tf.keras.models.load_model` when `model_def` is a path.

    Returns:
        model: The compiled model when exactly one was built, otherwise the list of compiled models.
    """
    strategy = fe.distribute_strategy
    scope = strategy.scope() if strategy else NonContext()
    with scope:
        if isinstance(model_def, str):
            built = tf.keras.models.load_model(model_def,
                                               custom_objects=custom_objects)
        else:
            built = model_def()
        models = to_list(built)
        names = to_list(model_name)
        optimizers = to_list(optimizer)
        losses = to_list(loss_name)
        # Every model needs exactly one name, optimizer and loss
        assert len(models) == len(names) == len(optimizers) == len(losses)
        for position, spec in enumerate(zip(models, names, optimizers, losses)):
            models[position] = _fe_compile(*spec)
    return models[0] if len(models) == 1 else models
示例#4
0
 def __init__(self, save_path: str, extra_objects: Any = None):
     """Prepare the report directories and document for a traceability report.

     Args:
         save_path: Location of the report. The directory `<dirname>/<basename up to the first '.'>` is created,
             together with a 'resources' sub-directory for figures.
         extra_objects: Not used directly by this constructor (see the comment near the end of the body).

     Raises:
         OSError: If the graphviz check or the locale check fails.
     """
     # Verify that graphviz is available on this machine
     try:
         pydot.Dot.create(pydot.Dot())
     except OSError as err:
         raise OSError(
             "Traceability requires that graphviz be installed. See www.graphviz.org/download for more information."
         ) from err
     # Verify that the system locale is functioning correctly
     try:
         locale.getlocale()
     except ValueError as err:
         # Adjacent literals are used instead of a backslash continuation so that the source indentation does not
         # end up inside the message.
         raise OSError("Your system locale is not configured correctly. On mac this can be resolved by adding "
                       "'export LC_ALL=en_US.UTF-8' and 'export LANG=en_US.UTF-8' to your ~/.bash_profile") from err
     super().__init__(inputs="*", mode="!infer")  # Claim wildcard inputs to get this trace sorted last
     # Report assets will get saved into a folder for portability
     path = os.path.normpath(save_path)
     path = os.path.abspath(path)
     root_dir = os.path.dirname(path)
     report = os.path.basename(path) or 'report'
     # A name that starts with '.' (e.g. '.pdf') has an empty stem; fall back to the default in that case too so
     # the report does not get written straight into root_dir.
     report = report.split('.')[0] or 'report'
     self.save_dir = os.path.join(root_dir, report)
     self.figure_dir = os.path.join(self.save_dir, 'resources')
     self.report_name = None  # This will be set later by the experiment name
     os.makedirs(self.save_dir, exist_ok=True)
     os.makedirs(self.figure_dir, exist_ok=True)
     # Other member variables
     self.config_tables = []
     # Extra objects will automatically get included in the report since this Trace is @traceable, so we don't need
     # to do anything with them. Referencing here to stop IDEs from flagging the argument as unused and removing it.
     to_list(extra_objects)
     self.doc = Document()
     self.log_splicer = None
示例#5
0
    def __init__(self, loss, models=None, keys=None, outputs=None):
        """Set up an op that runs in "train" mode and reads `loss` plus any extra `keys`.

        Args:
            loss: Key placed first in the inputs given to the parent constructor.
            models: Model(s) stored on `self.models`; a falsy value gives an empty list.
            keys: Additional input key(s); a falsy value gives none.
            outputs: Output key(s); there must be one per extra key plus one per model.

        Raises:
            AssertionError: If the number of outputs differs from `len(keys) + len(models)`.
        """
        self.models = to_list(models) if models else []
        extra_inputs = to_list(keys) if keys else []
        output_keys = to_list(outputs) if outputs else []

        expected_count = len(extra_inputs) + len(self.models)
        assert len(output_keys) == expected_count
        super().__init__(inputs=[loss] + extra_inputs, outputs=output_keys, mode="train")
示例#6
0
    def __init__(
        self,
        target_type: str,
        device: Optional[torch.device],
        ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]],
        postprocessing: Union[None, NumpyOp, Scheduler[NumpyOp], Iterable[Union[NumpyOp, Scheduler[NumpyOp]]]] = None
    ) -> None:
        """Build the ops, collect their models, and initialize the per-epoch bookkeeping.

        Args:
            target_type: Passed as `framework` when building each op.
            device: Passed as `device` when building each op.
            ops: The op(s) (or schedulers of ops) that make up this object.
            postprocessing: Optional op(s) stored on `self.postprocessing`.

        Raises:
            ValueError: If some, but not all, of the collected models have `mixed_precision` set.
        """
        self.ops = to_list(ops)
        self.target_type = target_type
        self.device = device
        for current_op in get_current_items(self.ops):
            current_op.build(framework=self.target_type, device=self.device)
        self.models = to_list(_collect_models(ops))
        self.postprocessing = to_list(postprocessing)
        self._verify_inputs()
        self.effective_inputs = {}
        self.effective_outputs = {}
        self.epoch_ops = []
        self.epoch_postprocessing = []
        self.epoch_models = set()
        self.epoch_state = {}
        # Either every model uses mixed precision or none of them do
        precision_flags = [model.mixed_precision for model in self.models]
        self.mixed_precision = any(precision_flags)

        if self.mixed_precision and not all(precision_flags):
            raise ValueError("Cannot mix full precision and mixed-precision models")
示例#7
0
 def __init__(
     self,
     inputs: Union[None, str, Iterable[str]] = None,
     outputs: Union[None, str, Iterable[str]] = None,
     mode: Union[None, str, Iterable[str]] = None,
 ) -> None:
     """Record the input keys, output keys, and mode specification.

     Args:
         inputs: Key(s) associated with this object; stored after passing through `to_list`.
         outputs: Key(s) produced by this object; stored after passing through `to_list`.
         mode: Mode specification; stored after passing through `to_set` and then `parse_modes`.
     """
     input_keys = to_list(inputs)
     output_keys = to_list(outputs)
     self.inputs = input_keys
     self.outputs = output_keys
     self.mode = parse_modes(to_set(mode))
     # The use-case here is rare enough that we don't want to add this to the init sig
     self.fe_monitor_names = set()
示例#8
0
 def __init__(self, ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]]) -> None:
     """Store the ops, collect their models, and initialize the per-epoch bookkeeping.

     Args:
         ops: The op(s) (or schedulers of ops) that make up this object.
     """
     self.ops = to_list(ops)
     self.models = to_list(_collect_models(ops))
     self._verify_inputs()
     # Bookkeeping containers start out empty
     self.effective_inputs = {}
     self.effective_outputs = {}
     self.epoch_ops = []
     self.epoch_models = set()
     self.epoch_state = {}
示例#9
0
 def __init__(
     self,
     inputs: Union[None, str, Iterable[str]] = None,
     outputs: Union[None, str, Iterable[str]] = None,
     mode: Union[None, str, Iterable[str]] = None,
 ) -> None:
     """Record the keys and modes, and remember whether keys were given as collections.

     Args:
         inputs: Key(s) associated with this object; stored after passing through `to_list`.
         outputs: Key(s) produced by this object; stored after passing through `to_list`.
         mode: Mode specification; stored after passing through `to_set` and then `parse_modes`.
     """
     self.inputs = to_list(inputs)
     self.outputs = to_list(outputs)
     self.mode = parse_modes(to_set(mode))
     # True only when the argument was neither None nor a single string
     self.in_list = inputs is not None and not isinstance(inputs, str)
     self.out_list = outputs is not None and not isinstance(outputs, str)
示例#10
0
 def __init__(
     self,
     inputs: Union[str, Iterable[str]],
     outputs: Union[str, Iterable[str]],
     mode: Union[None, str, Iterable[str]] = None,
     ds_id: Union[None, str, Iterable[str]] = None,
 ):
     """Forward list-normalized keys, plus `mode` and `ds_id`, to the parent constructor.

     Args:
         inputs: Input key(s); passed through `to_list` before being forwarded.
         outputs: Output key(s); passed through `to_list` before being forwarded.
         mode: Mode specification, forwarded unchanged.
         ds_id: Dataset id specification, forwarded unchanged.
     """
     input_keys = to_list(inputs)
     output_keys = to_list(outputs)
     super().__init__(inputs=input_keys, outputs=output_keys, mode=mode, ds_id=ds_id)
示例#11
0
 def __init__(
     self,
     inputs: Union[None, str, Iterable[str]] = None,
     outputs: Union[None, str, Iterable[str]] = None,
     mode: Union[None, str, Iterable[str]] = None,
     ds_id: Union[None, str, Iterable[str]] = None,
 ) -> None:
     """Record checked keys, modes, and dataset ids, and whether keys were given as collections.

     Args:
         inputs: Input key(s); stored after `to_list` and `check_io_names`.
         outputs: Output key(s); stored after `to_list` and `check_io_names`.
         mode: Mode specification; stored after `to_set` and `parse_modes`.
         ds_id: Dataset id specification; stored after `to_set` and `check_ds_id`.
     """
     input_keys = to_list(inputs)
     self.inputs = check_io_names(input_keys)
     output_keys = to_list(outputs)
     self.outputs = check_io_names(output_keys)
     self.mode = parse_modes(to_set(mode))
     self.ds_id = check_ds_id(to_set(ds_id))
     # True only when the argument was neither None nor a single string
     self.in_list = inputs is not None and not isinstance(inputs, str)
     self.out_list = outputs is not None and not isinstance(outputs, str)
示例#12
0
 def __init__(
     self,
     inputs: Union[str, Iterable[str]],
     outputs: Union[str, Iterable[str]],
     mode: Union[None, str, Iterable[str]] = None,
     ds_id: Union[None, str, Iterable[str]] = None,
     limit: Union[int, Tuple[int, int]] = 30,
 ):
     """Forward the keys to the parent constructor and store the converted `limit`.

     Args:
         inputs: Input key(s); passed through `to_list` before being forwarded.
         outputs: Output key(s); passed through `to_list` before being forwarded.
         mode: Mode specification, forwarded unchanged.
         ds_id: Dataset id specification, forwarded unchanged.
         limit: A single value or a pair; stored on `self.limit` after `param_to_range`.
     """
     input_keys = to_list(inputs)
     output_keys = to_list(outputs)
     super().__init__(inputs=input_keys, outputs=output_keys, mode=mode, ds_id=ds_id)
     self.limit = param_to_range(limit)
示例#13
0
 def __init__(self,
              inputs: Union[str, List[str]],
              outputs: Union[str, List[str]],
              indices: Union[None, str, List[str]] = None,
              mode: Union[None, str, Iterable[str]] = "eval"):
     """Configure an op whose inputs are the index keys followed by the data keys.

     Args:
         inputs: Key(s) of the data values; placed after the index keys in the parent's inputs.
         outputs: Output key(s), forwarded unchanged to the parent constructor.
         indices: Key(s) of the index values; their count is stored on `self.num_indices`.
         mode: Mode specification forwarded to the parent constructor.
     """
     index_keys = to_list(indices)
     self.num_indices = len(index_keys)
     # Build a new list instead of extending `index_keys` in place: `to_list` may hand back the caller's own list
     # object, and extending that would silently modify the caller's `indices` argument.
     combined_inputs = [*index_keys, *to_list(inputs)]
     super().__init__(inputs=combined_inputs, outputs=outputs, mode=mode)
     self.in_list, self.out_list = True, True
 def __init__(
     self,
     datasets: Union[FEDataset, Iterable[FEDataset]],
     num_samples: Union[int, Iterable[int]],
     probability: Optional[Iterable[float]] = None,
 ) -> None:
     """Store the datasets with their sampling configuration, validate them, and build the index maps.

     Args:
         datasets: Dataset(s) to combine; stored after passing through `to_list`.
         num_samples: Sample count(s); stored after passing through `to_list`.
         probability: Optional probabilities; stored after passing through `to_list`.
     """
     self.datasets, self.num_samples, self.probability = (
         to_list(datasets),
         to_list(num_samples),
         to_list(probability),
     )
     self.same_feature = False
     # Validation runs before the index maps are created
     self._check_input()
     self.reset_index_maps()
     self.pad_value = None
示例#15
0
 def __init__(
     self,
     inputs: Union[str, Iterable[str]],
     outputs: Union[str, Iterable[str]],
     mode: Union[None, str, Iterable[str]] = None,
     ds_id: Union[None, str, Iterable[str]] = None,
     threshold: Union[int, Tuple[int, int], float, Tuple[float, float]] = 256,
 ):
     """Forward the keys to the parent constructor and store `threshold` unchanged.

     Args:
         inputs: Input key(s); passed through `to_list` before being forwarded.
         outputs: Output key(s); passed through `to_list` before being forwarded.
         mode: Mode specification, forwarded unchanged.
         ds_id: Dataset id specification, forwarded unchanged.
         threshold: A single number or a pair of numbers; stored as given on `self.threshold`.
     """
     input_keys = to_list(inputs)
     output_keys = to_list(outputs)
     super().__init__(inputs=input_keys, outputs=output_keys, mode=mode, ds_id=ds_id)
     self.threshold = threshold
示例#16
0
 def __init__(self,
              inputs: Union[str, List[str]],
              finals: Union[str, List[str]],
              outputs: Union[str, List[str]],
              mode: Union[None, str, Iterable[str]] = None):
     """Configure a gradient op over matching `inputs`, `finals`, and `outputs` keys.

     The parent constructor receives the `inputs` keys followed by the `finals` keys as its inputs.

     Args:
         inputs: Key(s) to differentiate with respect to.
         finals: Key(s) of the values to differentiate.
         outputs: Key(s) under which results are written.
         mode: Mode specification forwarded to the parent constructor.

     Raises:
         AssertionError: If `inputs`, `finals`, and `outputs` do not all have the same length.
     """
     inputs = to_list(inputs)
     finals = to_list(finals)
     outputs = to_list(outputs)
     assert len(inputs) == len(finals) == len(outputs), \
         "GradientOp requires the same number of inputs, finals, and outputs"
     # Build a new list instead of extending in place: `to_list` may hand back the caller's own list object, and
     # extending that would silently modify the caller's argument.
     super().__init__(inputs=inputs + finals, outputs=outputs, mode=mode)
     self.retain_graph = True
示例#17
0
 def __init__(self, vis_steps=None, show_images=None, **plot_args):
     """Store the visualization settings.

     Args:
         vis_steps: Stored as given on `self.vis_steps`.
         show_images: Key(s) stored as a list on `self.show_images`; a falsy value gives an empty list.
         **plot_args: Extra keyword arguments kept on `self.plot_args`.
     """
     super().__init__()
     self.vis_steps = vis_steps
     self.plot_args = plot_args
     self.true_persist = False
     if show_images:
         self.show_images = to_list(show_images)
     else:
         self.show_images = []
     self.images = {}
示例#18
0
 def __init__(self,
              pipeline: Pipeline,
              network: BaseNetwork,
              epochs: int,
              train_steps_per_epoch: Optional[int] = None,
              eval_steps_per_epoch: Optional[int] = None,
              traces: Union[None, Trace, Scheduler[Trace],
                            Iterable[Union[Trace,
                                           Scheduler[Trace]]]] = None,
              log_steps: Optional[int] = 100,
              monitor_names: Union[None, str, Iterable[str]] = None):
     """Record the calling file and monitored keys, then build the `System` that holds the run configuration.

     Args:
         pipeline: Passed to `System` as `pipeline`.
         network: Passed to `System` as `network`; its `get_loss_keys()` are added to the monitored names.
         epochs: Passed to `System` as `total_epochs`.
         train_steps_per_epoch: Passed to `System` unchanged.
         eval_steps_per_epoch: Passed to `System` unchanged.
         traces: Trace(s) or schedulers of traces; passed to `System` after `to_list`.
         log_steps: None, or a non-negative number; passed to `System` unchanged.
         monitor_names: Extra key(s) to monitor, merged with the network's loss keys.

     Raises:
         AssertionError: If `log_steps` is negative.
     """
     self.traces_in_use = []
     # Index 2 selects the frame two levels above this constructor in the call stack.
     # NOTE(review): presumably a wrapper sits between the user's code and this __init__ - confirm; adding or
     # removing a call layer here would change which file gets recorded.
     self.filepath = os.path.realpath(
         inspect.stack()[2].filename)  # Record this for history tracking
     assert log_steps is None or log_steps >= 0, \
         "log_steps must be None or positive (or 0 to disable only train logging)"
     self.monitor_names = to_set(monitor_names) | network.get_loss_keys()
     # fe_summary() is evaluated here, after the attributes above have been assigned.
     self.system = System(network=network,
                          pipeline=pipeline,
                          traces=to_list(traces),
                          log_steps=log_steps,
                          total_epochs=epochs,
                          train_steps_per_epoch=train_steps_per_epoch,
                          eval_steps_per_epoch=eval_steps_per_epoch,
                          system_config=self.fe_summary())
示例#19
0
 def _document_init_params(self) -> None:
     """Write the "Parameters" section of the traceability document.

     Collects the ids of the network's models and the datasets used by each mode, then emits one group of config
     tables per object category through `_loop_tables`.
     """
     from fastestimator.estimator import Estimator  # Avoid circular import
     with self.doc.create(Section("Parameters")):
         model_types = (tf.keras.Model, torch.nn.Module)
         model_ids = {FEID(id(model)) for model in self.system.network.models if isinstance(model, model_types)}
         # Locate the datasets in order to provide extra details about them later in the summary
         datasets = {}
         for mode in ['train', 'eval', 'test']:
             pending = to_list(self.system.pipeline.data.get(mode, None))
             cursor = 0
             # `pending` grows while it is being walked: schedulers contribute all of their values
             while cursor < len(pending):
                 candidate = pending[cursor]
                 cursor += 1
                 if candidate:
                     feid = FEID(id(candidate))
                     if feid in datasets:
                         datasets[feid][0].add(mode)
                     else:
                         datasets[feid] = ({mode}, candidate)
                 if isinstance(candidate, Scheduler):
                     pending.extend(candidate.get_all_values())
         # Parse the config tables, one category after another; each call returns where the next one starts
         table_groups = [
             ((Estimator, BaseNetwork, Pipeline), "Base Classes"),
             (Scheduler, "Schedulers"),
             (Trace, "Traces"),
             (Op, "Ops"),
             ((Dataset, tf.data.Dataset), "Datasets"),
             ((tf.keras.Model, torch.nn.Module), "Models"),
             (types.FunctionType, "Functions"),
             ((np.ndarray, tf.Tensor, tf.Variable, torch.Tensor), "Tensors"),
             (Any, "Miscellaneous"),
         ]
         start = 0
         for classes, name in table_groups:
             start = self._loop_tables(start, classes=classes, name=name, model_ids=model_ids, datasets=datasets)
示例#20
0
 def __init__(self, ops: Union[TensorOp, List[TensorOp]]) -> None:
     """Fuse several ops that share one mode into a single op.

     The fused inputs are the keys the ops read that were not written by an earlier op in the list; the fused
     outputs are every key written by any op, in first-seen order.

     Args:
         ops: The op(s) to fuse.

     Raises:
         ValueError: If `ops` is empty or the ops do not all have the same mode.
     """
     ops = to_list(ops)
     if len(ops) < 1:
         raise ValueError("Fuse requires at least one op")
     mode = ops[0].mode
     fused_inputs, fused_outputs = [], []
     self.last_retain_idx = 0
     self.models = set()
     self.loss_keys = set()
     for position, op in enumerate(ops):
         if op.mode != mode:
             raise ValueError(
                 f"All Fuse ops must share the same mode, but got {mode} and {op.mode}"
             )
         for key in op.inputs:
             already_known = key in fused_inputs or key in fused_outputs
             if not already_known:
                 fused_inputs.append(key)
         for key in op.outputs:
             if key not in fused_outputs:
                 fused_outputs.append(key)
         # Set all of the internal ops to retain
         retain_result = op.fe_retain_graph(True)
         if retain_result is not None:
             # Keep tabs on the last one since it might be set to False
             self.last_retain_idx = position
         self.models |= op.get_fe_models()
         self.loss_keys |= op.get_fe_loss_keys()
     super().__init__(inputs=fused_inputs, outputs=fused_outputs, mode=mode)
     self.ops = ops
示例#21
0
 def __init__(self,
              model: Model,
              model_inputs: Union[str, Sequence[str]],
              model_outputs: Union[str, Sequence[str]],
              outputs: Union[str, List[str]] = "saliency"):
     """Assemble a "test"-mode network: Watch -> model -> Gather -> GradientOp.

     NOTE(review): `self.model_inputs` and `self.model_outputs` are read below but not assigned in this
     constructor; presumably they are defined elsewhere on the class - confirm.

     Args:
         model: The model wrapped in a non-trainable `ModelOp`.
         model_inputs: Input key(s) for the `ModelOp`.
         model_outputs: Output key(s) for the `ModelOp`.
         outputs: Key(s) under which the `GradientOp` results are written.
     """
     mode = "test"
     self.model_op = ModelOp(model=model,
                             mode=mode,
                             inputs=model_inputs,
                             outputs=model_outputs,
                             trainable=False)
     self.outputs = to_list(outputs)
     self.mode = mode
     index_template = "SaliencyNet_Target_Index_{}"
     intermediate_template = "SaliencyNet_Intermediate_{}"
     self.gather_keys = [index_template.format(key) for key in self.model_outputs]
     watch_op = Watch(inputs=self.model_inputs, mode=mode)
     # Each op gets its own freshly built key list rather than a shared one
     gather_op = Gather(inputs=self.model_outputs,
                        indices=self.gather_keys,
                        outputs=[intermediate_template.format(key) for key in self.model_outputs],
                        mode=mode)
     gradient_op = GradientOp(inputs=self.model_inputs,
                              finals=[intermediate_template.format(key) for key in self.model_outputs],
                              outputs=deepcopy(self.outputs),
                              mode=mode)
     self.network = Network(ops=[watch_op, self.model_op, gather_op, gradient_op])
示例#22
0
 def forward(self, data, state):
     """Apply `self.convert_fn` to every element of `data`.

     Args:
         data: A single value or a collection of values; normalized with `to_list`.
         state: Not used by this method.

     Returns:
         The converted element itself when there is exactly one, otherwise the list of converted elements.
     """
     items = to_list(data)
     for position, element in enumerate(items):
         items[position] = self.convert_fn(element)
     return items[0] if len(items) == 1 else items
示例#23
0
 def get_outputs(self, ds_ids: Union[None, str, List[str]]) -> List[str]:
     """List the output keys, followed by a "<output>|<ds_id>" variant per output and dataset id.

     Args:
         ds_ids: Dataset id(s); normalized with `to_list`.

     Returns:
         A new list: the plain outputs first, then the suffixed names grouped by output.
     """
     id_list = to_list(ds_ids)
     suffixed = [f"{output}|{ds_id}" for output in self.outputs for ds_id in id_list]
     return list(self.outputs) + suffixed
示例#24
0
 def __init__(self, ops: Union[NumpyOp, List[NumpyOp]]) -> None:
     """Fuse several ops that share one mode and ds_id into a single op.

     The fused inputs are the keys the ops read that were not written by an earlier op in the list. The fused
     outputs are the keys written by the ops, minus any key that a later `Delete` op takes as input.

     Args:
         ops: The op(s) to fuse.

     Raises:
         ValueError: If `ops` is empty, contains a `Batch` op, or the ops disagree on mode or ds_id.
     """
     ops = to_list(ops)
     if len(ops) < 1:
         raise ValueError("Fuse requires at least one op")
     mode = ops[0].mode
     ds_id = ops[0].ds_id
     fused_inputs, fused_outputs = [], []
     for op in ops:
         if isinstance(op, Batch):
             raise ValueError("Cannot nest the Batch op inside of Fuse")
         if op.mode != mode:
             raise ValueError(
                 f"All Fuse ops must share the same mode, but got {mode} and {op.mode}"
             )
         if op.ds_id != ds_id:
             raise ValueError(
                 f"All Fuse ops must share the same ds_id, but got {ds_id} and {op.ds_id}"
             )
         removes_keys = isinstance(op, Delete)
         for key in op.inputs:
             if removes_keys and key in fused_outputs:
                 # A Delete op consuming an internally produced key removes it from the fused outputs
                 fused_outputs.remove(key)
             elif not (key in fused_inputs or key in fused_outputs):
                 fused_inputs.append(key)
         for key in op.outputs:
             if key not in fused_outputs:
                 fused_outputs.append(key)
     super().__init__(inputs=fused_inputs, outputs=fused_outputs, mode=mode, ds_id=ds_id)
     self.ops = ops
示例#25
0
 def __init__(self,
              train_data: Union[None, DataSource,
                                Scheduler[DataSource]] = None,
              eval_data: Union[None, DataSource,
                               Scheduler[DataSource]] = None,
              test_data: Union[None, DataSource,
                               Scheduler[DataSource]] = None,
              batch_size: Union[None, int, Scheduler[int]] = None,
              ops: Union[None, NumpyOp, Scheduler[NumpyOp],
                         List[Union[NumpyOp, Scheduler[NumpyOp]]]] = None,
              num_process: Optional[int] = None,
              drop_last: bool = False,
              pad_value: Optional[Union[int, float]] = None,
              collate_fn: Optional[Callable] = None):
     """Store the per-mode data sources and loading options, then validate the constructor arguments.

     Args:
         train_data: Data for the "train" mode; omitted from `self.data` when falsy.
         eval_data: Data for the "eval" mode; omitted from `self.data` when falsy.
         test_data: Data for the "test" mode; omitted from `self.data` when falsy.
         batch_size: Stored unchanged on `self.batch_size`.
         ops: Op(s) or schedulers of ops; stored after passing through `to_list`.
         num_process: Stored on `self.num_process`. When None, `os.cpu_count()` is used, except on Windows
             (`os.name == 'nt'`) where 0 is used.
         drop_last: Stored unchanged on `self.drop_last`.
         pad_value: Stored unchanged on `self.pad_value`.
         collate_fn: Stored unchanged on `self.collate_fn`.
     """
     # Keep only the modes for which truthy data was supplied
     self.data = {
         x: y
         for (x, y) in zip(["train", "eval", "test"],
                           [train_data, eval_data, test_data]) if y
     }
     self.batch_size = batch_size
     self.ops = to_list(ops)
     # Chained conditional: an explicit num_process wins; otherwise cpu_count() off Windows and 0 on Windows
     self.num_process = num_process if num_process is not None else os.cpu_count(
     ) if os.name != 'nt' else 0
     self.drop_last = drop_last
     self.pad_value = pad_value
     self.collate_fn = collate_fn
     # Every local name except `self` is forwarded as a keyword argument. No other local is assigned in this
     # body, so these are the constructor arguments; introducing a new local variable above would add another
     # keyword to this call.
     self._verify_inputs(
         **{k: v
            for k, v in locals().items() if k != 'self'})
示例#26
0
 def __init__(
     self,
     datasets: Union[FEDataset, Iterable[FEDataset]],
     num_samples: Union[int, Iterable[int]],
     probability: Optional[Iterable[float]] = None,
 ) -> None:
     """Store the datasets with their sampling configuration, validate them, and perform the initial reset.

     Args:
         datasets: Dataset(s) to combine; stored after passing through `to_list`.
         num_samples: Sample count(s); stored after passing through `to_list`.
         probability: Optional probabilities; stored after passing through `to_list`.
     """
     self.datasets, self.num_samples, self.probability = (
         to_list(datasets),
         to_list(num_samples),
         to_list(probability),
     )
     self.same_feature = False
     self.all_fe_datasets = False
     self._check_input()
     self.index_maps = []
     # Remember the reset hook of every child dataset that provides one
     reset_fns = []
     for dataset in self.datasets:
         if hasattr(dataset, 'fe_reset_ds'):
             reset_fns.append(dataset.fe_reset_ds)
     self.child_reset_fns = reset_fns
     self.fe_reset_ds(seed=0)
示例#27
0
    def __init__(self,
                 inputs: Union[str, Iterable[str]],
                 outputs: Union[str, Iterable[str]],
                 output_shape: Sequence[int],
                 resize_mode: str = 'nearest',
                 mode: Union[None, str, Iterable[str]] = None,
                 ds_id: Union[None, str, Iterable[str]] = None):
        """Configure a resize op.

        Args:
            inputs: Input key(s); passed through `to_list` before being forwarded to the parent constructor.
            outputs: Output key(s); passed through `to_list` before being forwarded to the parent constructor.
            output_shape: Target shape, stored unchanged on `self.output_shape`.
            resize_mode: Either 'nearest' or 'area'.
            mode: Mode specification forwarded to the parent constructor.
            ds_id: Dataset id specification forwarded to the parent constructor.

        Raises:
            AssertionError: If `resize_mode` is not one of the supported values.
        """
        # `ds_id` used to be accepted and then dropped; forward it so that the restriction actually applies.
        # NOTE(review): assumes the parent constructor takes a `ds_id` keyword, as the other ops with this
        # signature do - confirm.
        super().__init__(inputs=to_list(inputs),
                         outputs=to_list(outputs),
                         mode=mode,
                         ds_id=ds_id)
        assert resize_mode in [
            'nearest', 'area'
        ], "Only following resize modes are supported: 'nearest', 'area' "
        self.output_shape = output_shape
        self.resize_mode = resize_mode
        # Misspelled legacy attribute name, kept as an alias in case code outside this constructor reads it
        self.reize_mode = resize_mode
示例#28
0
 def __init__(self,
              inputs=None,
              outputs=None,
              mode=None,
              rotation_range=0.,
              width_shift_range=0.,
              height_shift_range=0.,
              shear_range=0.,
              zoom_range=1.,
              flip_left_right=False,
              flip_up_down=False):
     """Store the augmentation ranges and flags and initialize the transform state.

     Args:
         inputs: Forwarded positionally to the parent constructor.
         outputs: Forwarded positionally to the parent constructor.
         mode: Forwarded positionally to the parent constructor.
         rotation_range: Stored unchanged.
         width_shift_range: Stored unchanged.
         height_shift_range: Stored unchanged.
         shear_range: Stored unchanged.
         zoom_range: A value or collection of values, each of which must be greater than zero; stored unchanged.
         flip_left_right: Stored on `self.flip_left_right_boolean`.
         flip_up_down: Stored on `self.flip_up_down_boolean`.

     Raises:
         AssertionError: If any zoom value is not positive.
     """
     super().__init__(inputs, outputs, mode)
     self.rotation_range = rotation_range
     self.width_shift_range = width_shift_range
     self.height_shift_range = height_shift_range
     self.shear_range = shear_range
     zoom_is_positive = [zoom > 0 for zoom in to_list(zoom_range)]
     assert all(zoom_is_positive), "zoom range should be positive"
     self.zoom_range = zoom_range
     self.flip_left_right_boolean = flip_left_right
     self.flip_up_down_boolean = flip_up_down
     # Transform state starts as a 3x3 identity with no known image size and both flip tensors at 0
     self.transform_matrix = tf.eye(3)
     self.width = None
     self.height = None
     self.do_flip_lr_tensor = tf.convert_to_tensor(0)
     self.do_flip_up_tensor = tf.convert_to_tensor(0)
示例#29
0
def reduce_max(tensor: Tensor,
               axis: Union[None, int, Sequence[int]] = None,
               keepdims: bool = False) -> Tensor:
    """Compute the maximum value along a given `axis` of a `tensor`.

    This method can be used with Numpy data:
    ```python
    n = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    b = fe.backend.reduce_max(n)  # 8
    b = fe.backend.reduce_max(n, axis=0)  # [[5, 6], [7, 8]]
    b = fe.backend.reduce_max(n, axis=1)  # [[3, 4], [7, 8]]
    b = fe.backend.reduce_max(n, axis=[0,2])  # [6, 8]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    t = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    b = fe.backend.reduce_max(t)  # 8
    b = fe.backend.reduce_max(t, axis=0)  # [[5, 6], [7, 8]]
    b = fe.backend.reduce_max(t, axis=1)  # [[3, 4], [7, 8]]
    b = fe.backend.reduce_max(t, axis=[0,2])  # [6, 8]
    ```

    This method can be used with PyTorch tensors:
    ```python
    p = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    b = fe.backend.reduce_max(p)  # 8
    b = fe.backend.reduce_max(p, axis=0)  # [[5, 6], [7, 8]]
    b = fe.backend.reduce_max(p, axis=1)  # [[3, 4], [7, 8]]
    b = fe.backend.reduce_max(p, axis=[0,2])  # [6, 8]
    b = fe.backend.reduce_max(p, axis=[-3,2])  # [6, 8]
    ```

    Args:
        tensor: The input value.
        axis: Which axis or collection of axes to compute the maximum along. Negative values count from the last
            dimension.
        keepdims: Whether to preserve the number of dimensions during the reduction.

    Returns:
        The maximum values of `tensor` along `axis`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        return tf.reduce_max(tensor, axis=axis, keepdims=keepdims)
    elif isinstance(tensor, torch.Tensor):
        ndim = len(tensor.shape)
        if axis is None:
            axis = list(range(ndim))
        # Convert in-range negative axes to their positive form before ordering. Otherwise a mix such as [-3, 2]
        # would be reduced in the wrong order and the later axis would refer to a dimension that has already
        # been removed. Out-of-range values are left alone so that torch still reports them.
        axis = [ax + ndim if -ndim <= ax < 0 else ax for ax in to_list(axis)]
        # Reduce from the highest axis down so that removing a dimension never shifts an axis still to be reduced
        for ax in sorted(axis, reverse=True):
            tensor = tensor.max(dim=ax, keepdim=keepdims)[0]
        return tensor
    elif isinstance(tensor, np.ndarray):
        # numpy takes multiple axes as a tuple, not a list
        if isinstance(axis, list):
            axis = tuple(axis)
        return np.max(tensor, axis=axis, keepdims=keepdims)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
示例#30
0
 def __init__(
     self,
     model: Model,
     model_inputs: Union[str, Sequence[str]],
     model_outputs: Union[str, Sequence[str]],
     class_key: Optional[str] = None,
     label_mapping: Optional[Dict[str, Any]] = None,
     outputs: Union[str, List[str]] = "saliency",
     samples: Union[None, int, Dict[str, Any]] = None,
     mode: Union[None, str, Iterable[str]] = ("eval", "test"),
     ds_id: Union[None, str, Iterable[str]] = None,
     smoothing: int = 25,
     integrating: Union[int, Tuple[int, int]] = (100, 6)
 ) -> None:
     """Set up per-mode sample bookkeeping and the `SaliencyNet` used to compute saliency.

     Args:
         model: The model handed to `SaliencyNet`.
         model_inputs: Input key(s) of the model; also part of this object's inputs.
         model_outputs: Output key(s) of the model.
         class_key: Optional key that is placed ahead of `model_inputs` in this object's inputs.
         label_mapping: Optional mapping; it is stored inverted (value -> key) on `self.label_mapping`.
         outputs: Output key(s), forwarded to the parent constructor and to `SaliencyNet`.
         samples: An int (the number of samples required per mode), a pre-built value that is used as-is for
             every mode, or None.
         mode: Mode name(s). A falsy value selects "train", "eval", and "test".
         ds_id: Dataset id specification forwarded to the parent constructor.
         smoothing: Stored unchanged on `self.smoothing`.
         integrating: Stored unchanged on `self.integrating`.
     """
     # Model outputs are required due to inability to statically determine the number of outputs from a pytorch model
     self.class_key = class_key
     self.model_outputs = to_list(model_outputs)
     super().__init__(inputs=to_list(self.class_key) +
                      to_list(model_inputs),
                      outputs=outputs,
                      mode=mode,
                      ds_id=ds_id)
     self.smoothing = smoothing
     self.integrating = integrating
     self.samples = {}
     self.n_found = {}
     self.n_required = {}
     # TODO - handle non-hashable labels
     self.label_mapping = {val: key
                           for key, val in label_mapping.items()
                           } if label_mapping else None
     # A bare string is a single mode name. Iterating it directly would register one entry per character
     # (e.g. 't', 'e', 's' for "test") instead of one entry for the mode itself.
     # NOTE(review): negated specifications such as "!train" are still used verbatim as keys here - confirm
     # whether they need to be expanded.
     if isinstance(mode, str):
         mode_names = (mode, )
     else:
         mode_names = mode or ("train", "eval", "test")
     for mode_name in mode_names:
         self.n_found[mode_name] = 0
         if isinstance(samples, int):
             # An integer is a quota: start with an empty collection and record how many are required
             self.samples[mode_name] = defaultdict(list)
             self.n_required[mode_name] = samples
         else:
             self.n_required[mode_name] = 0
             self.samples[mode_name] = defaultdict(list) if samples is None else samples
     self.salnet = SaliencyNet(model=model,
                               model_inputs=model_inputs,
                               model_outputs=model_outputs,
                               outputs=outputs)