def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes in the
  checkpoint file.  If `tensor_name` is provided, prints the content of the
  tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if tensor_name:
      # A specific tensor was requested: dump its full contents.
      print("tensor_name: ", tensor_name)
      print(checkpoints.load_variable(file_name, tensor_name))
    else:
      # No tensor requested: list every variable name with its shape.
      for name, shape in checkpoints.list_variables(file_name):
        print("%s\t%s" % (name, str(shape)))
  except Exception as e:  # pylint: disable=broad-except
    # This is a diagnostic CLI helper, so report rather than propagate.
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
def get_variable_names(self):
  """Returns list of all variable names in this model.

  Returns:
    List of names.
  """
  names = []
  for name, _ in checkpoints.list_variables(self.model_dir):
    names.append(name)
  return names
def testGetAllVariables(self):
  """Verifies list_variables reports every saved variable with its shape."""
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _create_checkpoints(session, checkpoint_dir)
  expected = [
      ("useful_scope/var4", [9, 9]),
      ("var1", [1, 10]),
      ("var2", [10, 10]),
      ("var3", [100, 100]),
  ]
  self.assertEqual(checkpoints.list_variables(checkpoint_dir), expected)
def weights_(self):
  """Returns the linear-part weights, excluding bias and optimizer slots."""
  # Optimizer slot variables share the weight's name plus an optional
  # numeric suffix; skip anything the pattern matches.
  optimizer_regex = r".*/" + self._optimizer.get_name() + r"(_\d)?$"
  values = {}
  for name, _ in checkpoints.list_variables(self._model_dir):
    is_weight = (name.startswith("linear/")
                 and name != "linear/bias_weight"
                 and re.match(optimizer_regex, name) is None)
    if is_weight:
      values[name] = checkpoints.load_variable(self._model_dir, name)
  # With a single weight variable, return its value directly.
  if len(values) == 1:
    return next(iter(values.values()))
  return values
def weights_(self):
  """Returns the linear weights (bias and optimizer variables excluded)."""
  optimizer_regex = r".*/" + self._optimizer.get_name() + r"(_\d)?$"
  all_names = [name for name, _ in
               checkpoints.list_variables(self._model_dir)]
  # Keep only genuine linear weights: drop the bias term and any
  # optimizer slot variables matched by the regex.
  values = {
      name: checkpoints.load_variable(self._model_dir, name)
      for name in all_names
      if (name.startswith("linear/") and name != "linear/bias_weight"
          and not re.match(optimizer_regex, name))
  }
  # Unwrap the dict when exactly one weight variable survives.
  if len(values) == 1:
    return list(values.values())[0]
  return values
def get_weights(self, model_dir):
  """Returns weights per feature of the linear part.

  Args:
    model_dir: Directory where model parameters, graph and etc. are saved.

  Returns:
    The weights created by this model (without the optimizer weights).
  """
  optimizer_regex = r".*/" + self._get_optimizer().get_name() + r"(_\d)?$"
  # Hoist the scope-derived names out of the loop.
  prefix = self._scope + "/"
  bias_name = self._scope + "/bias_weight"
  values = {}
  for name, _ in checkpoints.list_variables(model_dir):
    if (name.startswith(prefix) and name != bias_name
        and not re.match(optimizer_regex, name)):
      values[name] = checkpoints.load_variable(model_dir, name)
  # A single surviving weight is returned unwrapped for convenience.
  if len(values) == 1:
    return next(iter(values.values()))
  return values
def get_weights(self, model_dir):
  """Returns weights per feature of the linear part.

  Args:
    model_dir: Directory where model parameters, graph and etc. are saved.

  Returns:
    The weights created by this model (without the optimizer weights).
  """
  optimizer_regex = r".*/" + self._get_optimizer().get_name() + r"(_\d)?$"
  values = {}
  for name, _ in checkpoints.list_variables(model_dir):
    # Guard clauses: skip variables outside this scope, the bias term,
    # and any optimizer slot variables.
    if not name.startswith(self._scope + "/"):
      continue
    if name == self._scope + "/bias_weight":
      continue
    if re.match(optimizer_regex, name):
      continue
    values[name] = checkpoints.load_variable(model_dir, name)
  # Unwrap when there is exactly one weight variable.
  if len(values) == 1:
    return list(values.values())[0]
  return values
def get_variable_names(self):
  """Returns the names of all variables in this model's checkpoint."""
  return [var_name for var_name, _ in
          checkpoints.list_variables(self._model_dir)]