Beispiel #1
0
 def build(self, input_shape):
     """Freezes the configured layers, then builds the examples inputter.

     The "freeze_layers" entry of ``self.params`` may hold a single layer path
     or a list of layer paths. Each referenced layer is marked as non
     trainable and its dropout is set to 0.

     Args:
       input_shape: The input shape, forwarded to the examples inputter.
     """
     frozen_paths = self.params.get("freeze_layers")
     if frozen_paths:
         if isinstance(frozen_paths, list):
             paths = frozen_paths
         else:
             # A single path was given: handle it as a one-element list.
             paths = [frozen_paths]
         for path in paths:
             frozen = misc.index_structure(self, path)
             frozen.trainable = False
             # Disable dropout in frozen layers.
             misc.set_dropout(frozen, 0)
     self.examples_inputter.build(input_shape)
     self.built = True
Beispiel #2
0
 def trainable_weights(self):
     """Returns the trainable weights after freezing the configured layers.

     While ``self._frozen_layers`` is unset, the layers referenced by the
     "freeze_layers" entry of ``self.params`` (a single path or a list of
     paths) are marked as non trainable before delegating to the parent class.
     """
     if not self._frozen_layers:
         # Raise the flag first so that the freezing pass is not repeated.
         self._frozen_layers = True
         spec = self.params.get("freeze_layers")
         if spec:
             paths = spec if isinstance(spec, list) else [spec]
             for path in paths:
                 misc.index_structure(self, path).trainable = False
     return super(Model, self).trainable_weights
Beispiel #3
0
 def trainable_weights(self):
     """Returns the trainable weights after freezing the configured layers.

     While ``self._frozen_layers`` is unset, the layers referenced by the
     "freeze_layers" entry of ``self.params`` (a single path or a list of
     paths) are marked as non trainable and the number of non trainable
     weights is logged. The result is then delegated to the parent class.
     """
     if not self._frozen_layers:
         # Raise the flag first so that the freezing pass is not repeated.
         self._frozen_layers = True
         freeze_layers = self.params.get("freeze_layers")
         if freeze_layers:
             if not isinstance(freeze_layers, list):
                 freeze_layers = [freeze_layers]
             for layer_path in freeze_layers:
                 layer = misc.index_structure(self, layer_path)
                 layer.trainable = False
             # Pass the count as a logger argument so that the message is only
             # formatted when the record is actually emitted.
             # NOTE(review): this counts every non trainable weight of the
             # model, not only the ones frozen just above.
             tf.get_logger().info(
                 "%d weights are frozen by the freeze_layers parameter",
                 len(self.non_trainable_weights))
     return super(Model, self).trainable_weights
Beispiel #4
0
def average_checkpoints(model_dir,
                        output_dir,
                        trackables,
                        max_count=8,
                        model_key="model"):
  """Averages object-based checkpoints.

  The last checkpoint listed for :obj:`model_dir` is fully restored and the
  model variables are scaled by ``1 / N``, where ``N`` is the number of
  averaged checkpoints. The model variables of each remaining checkpoint are
  then accumulated with the same scaling. Optimizer slots and trackables other
  than the model keep the values restored from that last checkpoint.

  Args:
    model_dir: The directory containing checkpoints.
    output_dir: The directory that will contain the averaged checkpoint.
    trackables: A dictionary containing the trackable objects included in the
      checkpoint.
    max_count: The maximum number of checkpoints to average.
    model_key: The key in :obj:`trackables` that references the model.

  Returns:
    The path to the directory containing the averaged checkpoint.

  Raises:
    ValueError: if :obj:`output_dir` is the same as :obj:`model_dir` (paths
      are compared after normalization, so ``"run/"`` and ``"run"`` are
      considered identical).
    ValueError: if a model is not found in :obj:`trackables` or is not already
      built.
    ValueError: if no checkpoints are found in :obj:`model_dir`.
  """
  import os  # Local import: only needed to normalize the directory paths.

  # Compare normalized paths so that a trailing separator or a "./" component
  # cannot bypass the check and overwrite the source directory.
  if os.path.normpath(model_dir) == os.path.normpath(output_dir):
    raise ValueError("Model and output directory must be different")
  model = trackables.get(model_key)
  if model is None:
    raise ValueError("%s not found in trackables %s" % (model_key, trackables))
  if not model.built:
    raise ValueError("The model should be built before calling this function")

  checkpoint = tf.train.Checkpoint(**trackables)
  checkpoint_manager = tf.train.CheckpointManager(checkpoint, model_dir, max_to_keep=None)

  checkpoints_path = checkpoint_manager.checkpoints
  if not checkpoints_path:
    raise ValueError("No checkpoints found in %s" % model_dir)
  if len(checkpoints_path) > max_count:
    checkpoints_path = checkpoints_path[-max_count:]
  num_checkpoints = len(checkpoints_path)
  # The step is the suffix following the last "-" in the checkpoint path.
  last_step = int(checkpoints_path[-1].split("-")[-1])

  # Match on "<model_key>/" so that another trackable whose key merely starts
  # with the same characters (e.g. "model_ema") is not mistaken for the model.
  model_prefix = model_key + "/"

  tf.get_logger().info("Averaging %d checkpoints...", num_checkpoints)
  for i, checkpoint_path in enumerate(reversed(checkpoints_path)):
    tf.get_logger().info("Reading checkpoint %s...", checkpoint_path)
    if i == 0:
      # Restore every trackable from the last checkpoint, then scale the model
      # variables so that they hold the first term of the average.
      checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
      for variable in model.variables:
        variable.assign(variable / num_checkpoints)
    else:
      # Only read the model variables from the other checkpoints and add their
      # scaled values to the running average.
      reader = tf.train.load_checkpoint(checkpoint_path)
      for path in reader.get_variable_to_shape_map():
        if not path.startswith(model_prefix) or ".OPTIMIZER_SLOT" in path:
          continue
        variable_path = path.replace("/.ATTRIBUTES/VARIABLE_VALUE", "")
        variable = misc.index_structure(trackables, variable_path)
        value = reader.get_tensor(path)
        variable.assign_add(value / num_checkpoints)

  new_checkpoint_manager = tf.train.CheckpointManager(checkpoint, output_dir, max_to_keep=None)
  new_checkpoint_manager.save(checkpoint_number=last_step)
  return output_dir