def build(self, input_shape):
    """Builds the model.

    Layers listed in the ``freeze_layers`` parameter are marked as
    non-trainable and have their dropout disabled before the examples
    inputter is built.
    """
    to_freeze = self.params.get("freeze_layers")
    if to_freeze:
        # Accept a single layer path as well as a list of paths.
        paths = to_freeze if isinstance(to_freeze, list) else [to_freeze]
        for path in paths:
            frozen = misc.index_structure(self, path)
            frozen.trainable = False
            # Dropout serves no purpose on a frozen layer, so turn it off.
            misc.set_dropout(frozen, 0)
    self.examples_inputter.build(input_shape)
    self.built = True
def trainable_weights(self):
    """Returns the list of trainable weights.

    On first access, layers listed in the ``freeze_layers`` parameter are
    marked as non-trainable so they are excluded from the result.
    """
    if not self._frozen_layers:
        # Apply the freeze configuration exactly once.
        self._frozen_layers = True
        to_freeze = self.params.get("freeze_layers")
        if to_freeze:
            if not isinstance(to_freeze, list):
                to_freeze = [to_freeze]
            for path in to_freeze:
                misc.index_structure(self, path).trainable = False
    return super(Model, self).trainable_weights
def trainable_weights(self):
    """Returns the list of trainable weights.

    On first access, layers listed in the ``freeze_layers`` parameter are
    marked as non-trainable so they are excluded from the result, and the
    number of frozen weights is logged.
    """
    if not self._frozen_layers:
        # Apply the freeze configuration exactly once.
        self._frozen_layers = True
        freeze_layers = self.params.get("freeze_layers")
        if freeze_layers:
            if not isinstance(freeze_layers, list):
                freeze_layers = [freeze_layers]
            for layer_path in freeze_layers:
                layer = misc.index_structure(self, layer_path)
                layer.trainable = False
            # Pass the argument lazily instead of eagerly %-formatting the
            # message, matching the logging style used elsewhere in this file.
            tf.get_logger().info(
                "%d weights are frozen by the freeze_layers parameter",
                len(self.non_trainable_weights))
    return super(Model, self).trainable_weights
def average_checkpoints(model_dir, output_dir, trackables, max_count=8, model_key="model"):
    """Averages object-based checkpoints.

    Args:
      model_dir: The directory containing checkpoints.
      output_dir: The directory that will contain the averaged checkpoint.
      trackables: A dictionary containing the trackable objects included in the
        checkpoint.
      max_count: The maximum number of checkpoints to average.
      model_key: The key in :obj:`trackables` that references the model.

    Returns:
      The path to the directory containing the averaged checkpoint.

    Raises:
      ValueError: if :obj:`output_dir` is the same as :obj:`model_dir`.
      ValueError: if a model is not found in :obj:`trackables` or is not
        already built.
      ValueError: if no checkpoints are found in :obj:`model_dir`.
    """
    if model_dir == output_dir:
        raise ValueError("Model and output directory must be different")
    model = trackables.get(model_key)
    if model is None:
        raise ValueError("%s not found in trackables %s" % (model_key, trackables))
    if not model.built:
        raise ValueError("The model should be built before calling this function")

    checkpoint = tf.train.Checkpoint(**trackables)
    manager = tf.train.CheckpointManager(checkpoint, model_dir, max_to_keep=None)
    paths = manager.checkpoints
    if not paths:
        raise ValueError("No checkpoints found in %s" % model_dir)
    # Keep at most the max_count most recent checkpoints (a no-op slice when
    # fewer are available).
    paths = paths[-max_count:]
    num_checkpoints = len(paths)
    # Checkpoint filenames end with "-<step>"; reuse the step of the latest one.
    last_step = int(paths[-1].split("-")[-1])

    tf.get_logger().info("Averaging %d checkpoints...", num_checkpoints)
    for index, checkpoint_path in enumerate(reversed(paths)):
        tf.get_logger().info("Reading checkpoint %s...", checkpoint_path)
        if index == 0:
            # Fully restore the most recent checkpoint, then rescale every
            # model variable so the sum over all checkpoints is an average.
            checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
            for variable in model.variables:
                variable.assign(variable / num_checkpoints)
        else:
            # For older checkpoints, read raw tensors and accumulate their
            # scaled values into the matching model variables, skipping
            # optimizer slot variables and non-model entries.
            reader = tf.train.load_checkpoint(checkpoint_path)
            for tensor_path in six.iterkeys(reader.get_variable_to_shape_map()):
                if not tensor_path.startswith(model_key) or ".OPTIMIZER_SLOT" in tensor_path:
                    continue
                variable_path = tensor_path.replace("/.ATTRIBUTES/VARIABLE_VALUE", "")
                variable = misc.index_structure(trackables, variable_path)
                variable.assign_add(reader.get_tensor(tensor_path) / num_checkpoints)

    out_manager = tf.train.CheckpointManager(checkpoint, output_dir, max_to_keep=None)
    out_manager.save(checkpoint_number=last_step)
    return output_dir