示例#1
0
 def on_batch_end(self, data: Data) -> None:
     """Emit per-step TensorBoard output once a batch has finished.

     The model graphs are repainted whenever the set of models used this epoch
     differs from the set painted last. Nothing further happens outside of the
     'train' mode. While training, weight histograms are written every
     ``histogram_freq.freq`` global steps, and scalars, images and embeddings
     every ``update_freq.freq`` global steps, each only when the corresponding
     frequency is non-zero and step-based.

     Args:
         data: Batch-level key/value store that logged values are read from.
     """
     system = self.system
     if self.write_graph and system.network.epoch_models.symmetric_difference(self.painted_graphs):
         self.writer.write_epoch_models(mode=system.mode)
         self.painted_graphs = system.network.epoch_models
     if system.mode != 'train':
         return
     step = system.global_step
     hist = self.histogram_freq
     if hist.freq and hist.is_step and step % hist.freq == 0:
         self.writer.write_weights(mode=system.mode,
                                   models=system.network.models,
                                   step=step,
                                   visualize=self.paint_weights)
     upd = self.update_freq
     if not (upd.freq and upd.is_step and step % upd.freq == 0):
         return
     # Only entries whose value passes is_number are logged as scalars.
     self.writer.write_scalars(mode=system.mode,
                               step=step,
                               scalars=(item for item in data.items() if is_number(item[1])))
     # Requested image keys whose value is missing/None in the batch are skipped.
     image_pairs = ((key, data.get(key)) for key in self.write_images)
     self.writer.write_images(mode=system.mode,
                              step=step,
                              images=(pair for pair in image_pairs if pair[1] is not None))
     # Each spec's first three entries are data keys; specs whose first key has no value are skipped.
     embedding_rows = ((spec[0], data.get(spec[0]), data.get(spec[1]), data.get(spec[2]))
                       for spec in self.write_embeddings)
     self.writer.write_embeddings(mode=system.mode,
                                  step=step,
                                  embeddings=(row for row in embedding_rows if row[1] is not None))
示例#2
0
 def on_epoch_end(self, data: Data) -> None:
     """Emit epoch-level TensorBoard output.

     In 'train' mode, weight histograms are written when ``histogram_freq`` is
     epoch-based and the epoch index is a multiple of its frequency. Scalars,
     images and embeddings are written (in any mode) when ``update_freq`` is
     non-zero and is either step-based or divides the epoch index.

     Args:
         data: Epoch-level key/value store that logged values are read from.
     """
     system = self.system
     hist = self.histogram_freq
     weights_due = (system.mode == 'train' and hist.freq and not hist.is_step
                    and system.epoch_idx % hist.freq == 0)
     if weights_due:
         self.writer.write_weights(mode=system.mode,
                                   models=system.network.models,
                                   step=system.global_step,
                                   visualize=self.paint_weights)
     upd = self.update_freq
     if not upd.freq:
         return
     if not upd.is_step and system.epoch_idx % upd.freq != 0:
         return
     step = system.global_step
     # Only entries whose value passes is_number are logged as scalars.
     self.writer.write_scalars(mode=system.mode,
                               step=step,
                               scalars=(item for item in data.items() if is_number(item[1])))
     # Requested image keys whose value is missing/None are skipped.
     image_pairs = ((key, data.get(key)) for key in self.write_images)
     self.writer.write_images(mode=system.mode,
                              step=step,
                              images=(pair for pair in image_pairs if pair[1] is not None))
     # Each spec's first three entries are data keys; specs whose first key has no value are skipped.
     embedding_rows = ((spec[0], data.get(spec[0]), data.get(spec[1]), data.get(spec[2]))
                       for spec in self.write_embeddings)
     self.writer.write_embeddings(mode=system.mode,
                                  step=step,
                                  embeddings=(row for row in embedding_rows if row[1] is not None))
示例#3
0
def load_and_interpret(model_path, input_paths, baseline=-1, input_type='float32', dictionary_path=None,
                       strip_alpha=False, smooth_factor=7, save=False, save_dir=None):
    """A helper function to load inputs and invoke the interpretation API.

    Args:
        model_path: The path to the model file (str).
        input_paths: The paths to model input files [(str), ...]. Must contain at least one path.
        baseline: Either a number corresponding to the baseline for integration, or a path to a baseline file.
        input_type: The data type of the model inputs, ex 'float32'.
        dictionary_path: The path to a dictionary file encoding a 'class_idx'->'class_name' mapping.
        strip_alpha: Whether to collapse alpha channels when loading an input (bool).
        smooth_factor: How many iterations of the smoothing algorithm to run (int).
        save: Whether to save (True) or display (False) the resulting image.
        save_dir: Where to save the image if save=True. Defaults to the directory containing the model file.

    Raises:
        ValueError: If `input_paths` is empty.
    """
    if len(input_paths) == 0:
        # Checked before loading the model so the caller gets a clear message instead of numpy's
        # zero-size reduction error from np.maximum.reduce below.
        raise ValueError("input_paths must contain at least one input file")
    if save_dir is None:
        save_dir = os.path.dirname(model_path)
    network = keras.models.load_model(model_path)
    dic = load_dict(dictionary_path)
    inputs = [load_image(path, strip_alpha=strip_alpha) for path in input_paths]
    # Element-wise maximum over the input shapes; every input is then cropped/padded to the largest first two
    # dimensions (presumably height and width) so the batch can be stacked.
    max_shapes = np.maximum.reduce([inp.shape for inp in inputs], axis=0)
    tf_image = tf.stack(
        [tf.image.resize_with_crop_or_pad(tf.convert_to_tensor(im, dtype=input_type), max_shapes[0], max_shapes[1])
         for im in inputs],
        axis=0)
    if is_number(baseline):
        # A numeric baseline becomes a constant tensor with the same shape as the stacked batch.
        baseline_gen = tf.constant_initializer(float(baseline))
        baseline_image = baseline_gen(shape=tf_image.shape, dtype=input_type)
    else:
        # NOTE(review): unlike the inputs, the baseline file is loaded without `strip_alpha` and is not
        # cropped/padded to the batch shape — confirm interpret_model tolerates any resulting mismatch.
        baseline_image = load_image(baseline)
        baseline_image = tf.convert_to_tensor(baseline_image, dtype=input_type)

    interpret_model(network, tf_image, baseline_input=baseline_image, decode_dictionary=dic, smooth=smooth_factor,
                    save=save, save_path=save_dir)
示例#4
0
def load_and_saliency(model_path,
                      input_paths,
                      baseline=-1,
                      dictionary_path=None,
                      strip_alpha=False,
                      smooth_factor=7,
                      save=False,
                      save_dir=None):
    """A helper function to load inputs and invoke the saliency API.

    Args:
        model_path: The path to the model file (str).
        input_paths: The paths to model input files [(str), ...] or to a folder of inputs [(str)].
        baseline: Either a number corresponding to the baseline for integration, or a path to a baseline file.
        dictionary_path: The path to a dictionary file encoding a 'class_idx'->'class_name' mapping.
        strip_alpha: Whether to collapse alpha channels when loading an input (bool).
        smooth_factor: How many iterations of the smoothing algorithm to run (int).
        save: Whether to save (True) or display (False) the resulting image.
        save_dir: Where to save the image if save=True. Defaults to the directory containing the model file;
            ignored when save=False.

    Raises:
        ValueError: If no inputs remain after resolving `input_paths` (an empty list or an empty folder).
    """
    # A single folder argument is expanded into the input files that PathLoader finds in it.
    if len(input_paths) == 1 and os.path.isdir(input_paths[0]):
        loader = PathLoader(input_paths[0])
        input_paths = [path[0] for path in loader.path_pairs]
    if len(input_paths) == 0:
        # Fail before loading the model, with a clearer message than numpy's zero-size reduction error
        # from np.maximum.reduce below.
        raise ValueError("No inputs found: input_paths must name at least one file or a non-empty folder")
    if not save:
        # No save path is passed on, so the result is displayed rather than saved.
        save_dir = None
    elif save_dir is None:
        save_dir = os.path.dirname(model_path)
    network = keras.models.load_model(model_path, compile=False)
    input_type = network.input.dtype
    input_shape = network.input.shape
    # Presumably a rank-3 input is (batch, height, width) with no channel axis; otherwise axis 3 holds the channels.
    n_channels = 0 if len(input_shape) == 3 else input_shape[3]

    dic = load_dict(dictionary_path)
    inputs = [load_image(path, strip_alpha=strip_alpha, channels=n_channels) for path in input_paths]
    # Element-wise maximum over the input shapes; every input is then cropped/padded to the largest first two
    # dimensions so the batch can be stacked.
    max_shapes = np.maximum.reduce([inp.shape for inp in inputs], axis=0)
    tf_image = tf.stack(
        [tf.image.resize_with_crop_or_pad(tf.convert_to_tensor(im, dtype=input_type), max_shapes[0], max_shapes[1])
         for im in inputs],
        axis=0)
    if is_number(baseline):
        # A numeric baseline becomes a constant tensor with the same shape as the stacked batch.
        baseline_gen = tf.constant_initializer(float(baseline))
        baseline_image = baseline_gen(shape=tf_image.shape, dtype=input_type)
    else:
        # NOTE(review): unlike the inputs, the baseline file is loaded without `strip_alpha`/`channels` and is not
        # cropped/padded to the batch shape — confirm visualize_saliency tolerates any resulting mismatch.
        baseline_image = load_image(baseline)
        baseline_image = tf.convert_to_tensor(baseline_image, dtype=input_type)

    visualize_saliency(network,
                       tf_image,
                       baseline_input=baseline_image,
                       decode_dictionary=dic,
                       smooth=smooth_factor,
                       save_path=save_dir)
示例#5
0
 def _infer_keys(self, state):
     """Choose which ``state`` entries are monitored and store their keys on ``self.keys``.

     An entry qualifies when its value is a string, passes ``is_number``, or
     has a ``numpy()`` method returning an array with exactly one dimension.
     The qualifying keys are stored in sorted order.
     """
     def _qualifies(value):
         if isinstance(value, str) or is_number(value):
             return True
         return hasattr(value, "numpy") and len(value.numpy().shape) == 1

     self.keys = sorted(key for key, value in state.items() if _qualifies(value))
示例#6
0
 def on_epoch_end(self, data: Data) -> None:
     """Emit epoch-level TensorBoard output, including batch-aggregated embeddings.

     In order:
       1. In 'train' mode, write weight histograms when ``histogram_freq`` is
          epoch-based and the epoch index is a multiple of its frequency.
       2. For every buffer in ``self.collected_embeddings``, concatenate the
          per-batch values and write them, then clear all buffers. Buffers that
          are empty, or in which any batch lacks the embedding value, are
          skipped. This matches the other ``write_embeddings`` call sites, which
          drop entries whose embedding value is None.
       3. Write embeddings present in ``data`` when ``embedding_freq`` is
          non-zero and either step-based or dividing the epoch index.
       4. Write scalars and images under the same rule, using ``update_freq``.

     Args:
         data: Epoch-level key/value store that logged values are read from.
     """

     def _concat_field(rows, idx):
         # A field is only usable when every buffered batch supplied a value for it.
         if any(row[idx] is None for row in rows):
             return None
         return concat([row[idx] for row in rows])

     if self.system.mode == 'train' and self.histogram_freq.freq and not self.histogram_freq.is_step and \
             self.system.epoch_idx % self.histogram_freq.freq == 0:
         self.writer.write_weights(mode=self.system.mode,
                                   models=self.system.network.models,
                                   step=self.system.global_step,
                                   visualize=self.paint_weights)
     # Write out any embeddings which were aggregated over batches
     for name, val_list in self.collected_embeddings.items():
         # An empty buffer would reach concat([]); a missing embedding value cannot be visualized.
         embeddings = _concat_field(val_list, 0) if val_list else None
         if embeddings is None:
             continue
         labels = _concat_field(val_list, 1)
         imgs = _concat_field(val_list, 2)
         self.writer.write_embeddings(mode=self.system.mode,
                                      step=self.system.global_step,
                                      embeddings=[(name, embeddings, labels, imgs)])
     self.collected_embeddings.clear()
     # Get any embeddings which were generated externally on epoch end
     if self.embedding_freq.freq and (self.embedding_freq.is_step
                                      or self.system.epoch_idx %
                                      self.embedding_freq.freq == 0):
         self.writer.write_embeddings(
             mode=self.system.mode,
             step=self.system.global_step,
             embeddings=filter(
                 lambda x: x[1] is not None,
                 map(
                     lambda t:
                     (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])),
                     self.write_embeddings)))
     if self.update_freq.freq and (self.update_freq.is_step
                                   or self.system.epoch_idx %
                                   self.update_freq.freq == 0):
         self.writer.write_scalars(mode=self.system.mode,
                                   step=self.system.global_step,
                                   scalars=filter(lambda x: is_number(x[1]),
                                                  data.items()))
         self.writer.write_images(mode=self.system.mode,
                                  step=self.system.global_step,
                                  images=filter(
                                      lambda x: x[1] is not None,
                                      map(lambda y: (y, data.get(y)),
                                          self.write_images)))
示例#7
0
 def on_epoch_end(self, state):
     """Write epoch-level summaries using the summary writer for ``state['mode']``.

     Every numeric entry of ``state`` not listed in ``ignore_keys`` is written
     as an "epoch_"-prefixed scalar, and each requested image key (excluding the
     boolean markers True/False) with a non-None value is written as an image.
     In 'train' mode, weights and embeddings are additionally logged when
     their respective frequencies are set and divide the epoch number.
     """
     mode = state['mode']
     with self.summary_writers[mode].as_default():
         for name in state.keys() - self.ignore_keys:
             value = state[name]
             if not is_number(value):
                 continue
             tf.summary.scalar("epoch_" + name, value, step=state['epoch'])
         for name in self.write_images - {True, False}:
             image = state.get(name)
             if image is None:
                 continue
             tf.summary.image(name, image, step=state['epoch'])
         if mode == 'train':
             if self.histogram_freq and state['epoch'] % self.histogram_freq == 0:
                 self._log_weights(epoch=state['epoch'])
             if self.embeddings_freq and state['epoch'] % self.embeddings_freq == 0:
                 self._log_embeddings(state)
示例#8
0
 def on_batch_end(self, data: Data) -> None:
     """Emit per-step TensorBoard output and buffer embeddings for per-epoch display.

     In order:
       1. Repaint the model graphs whenever the set of models used this epoch
          differs from the set painted last.
       2. Outside of 'train' mode, when ``embedding_freq`` is epoch-based and the
          epoch index is a multiple of it, append this batch's values for every
          requested embedding present in ``data`` to ``self.collected_embeddings``.
       3. In any mode, write embeddings every ``embedding_freq.freq`` global
          steps when that frequency is step-based.
       4. In 'train' mode only, write weight histograms and scalars/images on
          their respective step-based frequencies.

     Args:
         data: Batch-level key/value store that logged values are read from.
     """
     system = self.system
     if self.write_graph and system.network.epoch_models.symmetric_difference(self.painted_graphs):
         self.writer.write_epoch_models(mode=system.mode, epoch=system.epoch_idx)
         self.painted_graphs = system.network.epoch_models
     emb = self.embedding_freq
     # Per-epoch embeddings are buffered from non-training batches only; training batches are never aggregated.
     buffer_batch = (system.mode != 'train' and emb.freq and not emb.is_step
                     and system.epoch_idx % emb.freq == 0)
     if buffer_batch:
         for name, lbl, img in self.write_embeddings:
             if name not in data:
                 continue
             self.collected_embeddings[name].append((data.get(name), data.get(lbl), data.get(img)))
     # Step-based embedding output is written directly, in every mode.
     if emb.freq and emb.is_step and system.global_step % emb.freq == 0:
         embedding_rows = ((spec[0], data.get(spec[0]), data.get(spec[1]), data.get(spec[2]))
                           for spec in self.write_embeddings)
         self.writer.write_embeddings(mode=system.mode,
                                      step=system.global_step,
                                      embeddings=(row for row in embedding_rows if row[1] is not None))
     if system.mode != 'train':
         return
     step = system.global_step
     hist = self.histogram_freq
     if hist.freq and hist.is_step and step % hist.freq == 0:
         self.writer.write_weights(mode=system.mode,
                                   models=system.network.models,
                                   step=step,
                                   visualize=self.paint_weights)
     upd = self.update_freq
     if not (upd.freq and upd.is_step and step % upd.freq == 0):
         return
     # Only entries whose value passes is_number are logged as scalars.
     self.writer.write_scalars(mode=system.mode,
                               step=step,
                               scalars=(item for item in data.items() if is_number(item[1])))
     # Requested image keys whose value is missing/None in the batch are skipped.
     image_pairs = ((key, data.get(key)) for key in self.write_images)
     self.writer.write_images(mode=system.mode,
                              step=step,
                              images=(pair for pair in image_pairs if pair[1] is not None))
示例#9
0
 def on_batch_end(self, state):
     """Handle profiling and per-batch summaries after a batch.

     Only 'train' mode batches are processed. While a trace is active it is
     logged on every batch; otherwise tracing is enabled when ``train_step``
     equals ``profile_batch - 1``. Unless ``update_freq`` is 'epoch', every
     ``update_freq`` train steps the numeric state entries (minus
     ``ignore_keys``) are written as "batch_"-prefixed scalars, and each
     requested image key (excluding True/False markers) with a non-None value
     is written as an image.
     """
     if state['mode'] != 'train':
         return
     step = state['train_step']
     if self.is_tracing:
         self._log_trace(step)
     elif step == self.profile_batch - 1:
         self._enable_trace()
     summaries_due = self.update_freq != 'epoch' and step % self.update_freq == 0
     if not summaries_due:
         return
     with self.summary_writers[state['mode']].as_default():
         for name in state.keys() - self.ignore_keys:
             value = state[name]
             if not is_number(value):
                 continue
             tf.summary.scalar("batch_" + name, value, step=step)
         for name in self.write_images - {True, False}:
             image = state.get(name)
             if image is None:
                 continue
             tf.summary.image(name, image, step=step)