def testInitFromCheckpoint(self):
    """Nested-scope variables are restored across two mapping calls.

    The first call combines a variable-name mapping with a scope-to-scope
    mapping; the second combines a variable-name mapping with a mapping onto
    a Variable object.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      want1, want2, want3, want4 = _create_checkpoints(sess, ckpt_dir)

    # Restore side: a fresh graph with its own session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          first = variable_scope.get_variable("my1", [1, 10])
          with variable_scope.variable_scope("some_other_scope"):
            second = variable_scope.get_variable("my2", [10, 10])
            with variable_scope.variable_scope("other_useful_scope"):
              fourth = variable_scope.get_variable("var4", [9, 9])
        third = variable_scope.get_variable("my3", [100, 100])

        name_and_scope_map = {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, name_and_scope_map)
        name_and_object_map = {
            "var2": "some_scope/some_other_scope/my2",
            "var3": third,
        }
        checkpoint_utils.init_from_checkpoint(ckpt_dir, name_and_object_map)

        sess.run(variables.global_variables_initializer())
        restored_pairs = ((first, want1), (second, want2), (third, want3),
                          (fourth, want4))
        for graph_var, want in restored_pairs:
          self.assertAllEqual(graph_var.eval(sess), want)

        # The checkpoint tensors must not be embedded in the graph itself,
        # which this bound on the serialized GraphDef size checks.
        self.assertLess(len(str(sess.graph.as_graph_def())), 27000)
示例#2
0
 def begin(self):
     """Register checkpoint initializers for the embedding variables.

     Does nothing when ``self._init_checkpoint`` is None. Otherwise builds an
     assignment map with ``create_embedding_map``, logs it, and passes it to
     ``checkpoint_utils.init_from_checkpoint``.
     """
     if self._init_checkpoint is None:
         return
     init_map = create_embedding_map(self._init_checkpoint)
     tf.logging.info("embeddings to be initialized from {}: {}".format(
         self._init_checkpoint, init_map))
     checkpoint_utils.init_from_checkpoint(self._init_checkpoint, init_map)
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
  """Initialize `variable` from a checkpoint tensor when a path is given.

  Args:
    checkpoint_path: None to skip restoring; otherwise a
      `(path, tensor_name)` pair.
    variable: Sequence of variables. A single-element sequence is unwrapped
      before being used as the assignment target.
  """
  if checkpoint_path is None:
    return
  ckpt_path, ckpt_tensor = checkpoint_path
  target = variable[0] if len(variable) == 1 else variable
  checkpoint_utils.init_from_checkpoint(ckpt_path, {ckpt_tensor: target})
示例#4
0
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
  """Initialize `variable` from a checkpoint tensor when a path is given.

  Args:
    checkpoint_path: None to skip restoring; otherwise a
      `(path, tensor_name)` pair.
    variable: Sequence of variables. A single-element sequence is unwrapped
      before being used as the assignment target.
  """
  if checkpoint_path is None:
    return
  ckpt_path, ckpt_tensor = checkpoint_path
  target = variable[0] if len(variable) == 1 else variable
  checkpoint_utils.init_from_checkpoint(ckpt_path, {ckpt_tensor: target})
  def testInitFromPartitionVar(self):
    """Partitioned variables are restored from the partitioned checkpoint.

    The checkpoint is the one written by `_create_partition_checkpoints`.
    The first graph restores through a variable-name mapping and a
    scope-to-scope mapping. The second maps directly onto the list of
    partition variables.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      expected = _create_partition_checkpoints(sess, ckpt_dir)

    def partitioned_var(var_name):
      # 100x100 variable split along axis 0 into at most 5 partitions.
      return variable_scope.get_variable(
          name=var_name,
          shape=[100, 100],
          initializer=init_ops.truncated_normal_initializer(0.5),
          partitioner=partitioned_variables.min_max_variable_partitioner(
              max_partitions=5, axis=0, min_slice_size=8 << 10))

    # Fresh graph: restore through a name mapping and a scope mapping.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          by_name_parts = partitioned_var("my1")._get_variable_list()
        with variable_scope.variable_scope("some_other_scope"):
          by_scope_parts = partitioned_var("var1")._get_variable_list()

        checkpoint_utils.init_from_checkpoint(ckpt_dir, {
            "scope/var1": "some_scope/my1",
            "scope/": "some_other_scope/"})

        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(sess.run(by_name_parts), expected)
        self.assertAllEqual(sess.run(by_scope_parts), expected)

    # Fresh graph: restore by mapping straight onto the partition list.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          parts = partitioned_var("my1")._get_variable_list()

        checkpoint_utils.init_from_checkpoint(ckpt_dir, {"scope/var1": parts})

        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(sess.run(parts), expected)
    def testInitFromPartitionVar(self):
        """A partitioned variable is restored from the checkpoint's "var1".

        The same restore runs twice, each time in a fresh graph. The first
        pass maps "var1" to the variable's name string; the second maps it
        onto the list of partition variables.
        """
        ckpt_dir = self.get_temp_dir()
        with self.test_session() as sess:
            expected = _create_partition_checkpoints(sess, ckpt_dir)

        for use_var_list in (False, True):
            # Each pass gets its own graph and session.
            with ops.Graph().as_default() as graph:
                with self.test_session(graph=graph) as sess:
                    with variable_scope.variable_scope("some_scope"):
                        partitioner = (
                            partitioned_variables.min_max_variable_partitioner(
                                max_partitions=5,
                                axis=0,
                                min_slice_size=8 << 10))
                        parted = variable_scope.get_variable(
                            name="my1",
                            shape=[100, 100],
                            initializer=init_ops.truncated_normal_initializer(
                                0.5),
                            partitioner=partitioner)
                        parts = parted._get_variable_list()

                    target = parts if use_var_list else "some_scope/my1"
                    checkpoint_utils.init_from_checkpoint(
                        ckpt_dir, {"var1": target})

                    sess.run(variables.global_variables_initializer())
                    self.assertAllEqual(sess.run(parts), expected)
    def testInitWithScopeDoesNotCaptureSuffixes(self):
        """A "useful_scope/" mapping must leave "useful_scope_1/" alone.

        `var5` lives in a sibling scope whose name merely starts with
        "useful_scope", so it must keep its own initializer value.
        """
        ckpt_dir = self.get_temp_dir()
        with self.cached_session() as sess:
            _, _, _, expected_var4 = _create_checkpoints(sess, ckpt_dir)

        var5_value = [[1.0, 2.0], [3.0, 4.0]]
        with ops.Graph().as_default() as graph:
            with variable_scope.variable_scope("useful_scope"):
                restored = variable_scope.get_variable("var4", [9, 9])
            with variable_scope.variable_scope("useful_scope_1"):
                untouched = variable_scope.get_variable(
                    "var5", initializer=var5_value)

            checkpoint_utils.init_from_checkpoint(
                ckpt_dir, {"useful_scope/": "useful_scope/"})
            with self.session(graph=graph) as sess:
                sess.run(variables.global_variables_initializer())
                self.assertAllEqual(restored.eval(sess), expected_var4)
                self.assertAllEqual(untouched.eval(sess), var5_value)
  def testInitToRootCheckpoint(self):
    """A root-to-root ("/" -> "/") mapping restores same-named variables."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      want1, want2, want3, want4 = _create_checkpoints(sess, ckpt_dir)

    # Fresh graph whose variable names mirror the checkpoint's.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        restored = [
            variable_scope.get_variable(var_name, var_shape)
            for var_name, var_shape in (("var1", [1, 10]),
                                        ("var2", [10, 10]),
                                        ("var3", [100, 100]))
        ]
        with variable_scope.variable_scope("useful_scope"):
          restored.append(variable_scope.get_variable("var4", [9, 9]))

        checkpoint_utils.init_from_checkpoint(ckpt_dir, {"/": "/"})

        sess.run(variables.global_variables_initializer())
        for graph_var, want in zip(restored, (want1, want2, want3, want4)):
          self.assertAllEqual(graph_var.eval(sess), want)
    def testInitToRootCheckpoint(self):
        """A root-to-root ("/" -> "/") mapping restores same-named variables."""
        ckpt_dir = self.get_temp_dir()
        with self.cached_session() as sess:
            want1, want2, want3, want4 = _create_checkpoints(sess, ckpt_dir)

        # Fresh graph whose variable names mirror the checkpoint's.
        with ops.Graph().as_default() as graph:
            with self.session(graph=graph) as sess:
                var1 = variable_scope.get_variable("var1", [1, 10])
                var2 = variable_scope.get_variable("var2", [10, 10])
                var3 = variable_scope.get_variable("var3", [100, 100])
                with variable_scope.variable_scope("useful_scope"):
                    var4 = variable_scope.get_variable("var4", [9, 9])

                root_to_root = {"/": "/"}
                checkpoint_utils.init_from_checkpoint(ckpt_dir, root_to_root)

                sess.run(variables.global_variables_initializer())
                checks = ((var1, want1), (var2, want2), (var3, want3),
                          (var4, want4))
                for graph_var, want in checks:
                    self.assertAllEqual(graph_var.eval(sess), want)
    def testInitFromCheckpoint(self):
        """Nested-scope variables are restored across two mapping calls.

        The first call combines a variable-name mapping with a scope-to-scope
        mapping; the second combines a variable-name mapping with a mapping
        onto a Variable object.
        """
        ckpt_dir = self.get_temp_dir()
        with self.cached_session() as sess:
            want1, want2, want3, want4 = _create_checkpoints(sess, ckpt_dir)

        # Restore side: a fresh graph with its own session.
        with ops.Graph().as_default() as graph:
            with self.session(graph=graph) as sess:
                with variable_scope.variable_scope("some_scope"):
                    got1 = variable_scope.get_variable("my1", [1, 10])
                    with variable_scope.variable_scope("some_other_scope"):
                        got2 = variable_scope.get_variable("my2", [10, 10])
                        with variable_scope.variable_scope(
                                "other_useful_scope"):
                            got4 = variable_scope.get_variable("var4", [9, 9])
                got3 = variable_scope.get_variable("my3", [100, 100])

                first_map = {
                    "var1":
                    "some_scope/my1",
                    "useful_scope/":
                    "some_scope/some_other_scope/other_useful_scope/",
                }
                second_map = {
                    "var2": "some_scope/some_other_scope/my2",
                    "var3": got3,
                }
                for assignment in (first_map, second_map):
                    checkpoint_utils.init_from_checkpoint(ckpt_dir, assignment)

                sess.run(variables.global_variables_initializer())
                pairs = ((got1, want1), (got2, want2), (got3, want3),
                         (got4, want4))
                for graph_var, want in pairs:
                    self.assertAllEqual(graph_var.eval(sess), want)

                # The checkpoint tensors must not be embedded in the graph
                # itself, which this bound on the GraphDef size checks.
                self.assertLess(len(str(sess.graph.as_graph_def())), 27000)
    def testInitFromCheckpointMissing(self):
        """Bad paths, names, shapes and scope mappings raise errors."""
        ckpt_dir = self.get_temp_dir()
        with self.cached_session() as sess:
            _, _, _, _ = _create_checkpoints(sess, ckpt_dir)

        # New graph and session.
        with ops.Graph().as_default() as graph:
            with self.session(graph=graph):
                with variable_scope.variable_scope("some_scope"):
                    variable_scope.get_variable("my1", [10, 10])
                    variable_scope.get_variable(
                        "my2", [1, 10],
                        dtype=dtypes.int64,
                        initializer=init_ops.zeros_initializer())

                # (expected error, checkpoint location, assignment map)
                failure_cases = (
                    # No directory.
                    (errors_impl.OpError, "no_dir",
                     {"var1": "some_scope/my1"}),
                    # No variable in checkpoint.
                    (ValueError, ckpt_dir, {"no_var": "some_scope/my1"}),
                    # No variable in the graph.
                    (ValueError, ckpt_dir, {"var3": "some_scope/no_var"}),
                    # Shape mismatch.
                    (ValueError, ckpt_dir, {"var1": "some_scope/my1"}),
                    # Variable 'my1' and 'my2' are missing in given
                    # checkpoint scope.
                    (ValueError, ckpt_dir, {"useful_scope/": "some_scope/"}),
                    # Mapping is not to scope name.
                    (ValueError, ckpt_dir, {"useful_scope": "some_scope/"}),
                )
                for expected_error, location, mapping in failure_cases:
                    with self.assertRaises(expected_error):
                        checkpoint_utils.init_from_checkpoint(location, mapping)
示例#12
0
def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Initialize variables from a checkpoint.

  Thin wrapper that forwards both arguments unchanged to
  `checkpoint_utils.init_from_checkpoint` and returns None.
  See `tf.contrib.framework.init_from_checkpoint` for the full contract.

  Args:
    checkpoint_dir: Checkpoint location, passed through as the first argument.
    assignment_map: Mapping from checkpoint names/scopes to graph variables
      or scopes, passed through as the second argument.
  """
  checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
  def testInitFromCheckpointMissing(self):
    """Bad paths, names, shapes and scope mappings raise errors."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _, _, _, _ = _create_checkpoints(sess, ckpt_dir)

    def assert_init_raises(error_type, location, assignment):
      # Each restore attempt must fail with exactly the given error type.
      with self.assertRaises(error_type):
        checkpoint_utils.init_from_checkpoint(location, assignment)

    # New graph and session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph):
        with variable_scope.variable_scope("some_scope"):
          variable_scope.get_variable("my1", [10, 10])
          variable_scope.get_variable(
              "my2", [1, 10],
              dtype=dtypes.int64,
              initializer=init_ops.zeros_initializer())

        # No directory.
        assert_init_raises(errors_impl.OpError, "no_dir",
                           {"var1": "some_scope/my1"})
        # No variable in checkpoint.
        assert_init_raises(ValueError, ckpt_dir,
                           {"no_var": "some_scope/my1"})
        # No variable in the graph.
        assert_init_raises(ValueError, ckpt_dir,
                           {"var3": "some_scope/no_var"})
        # Shape mismatch.
        assert_init_raises(ValueError, ckpt_dir,
                           {"var1": "some_scope/my1"})
        # Variable 'my1' and 'my2' are missing in given checkpoint scope.
        assert_init_raises(ValueError, ckpt_dir,
                           {"useful_scope/": "some_scope/"})
        # Mapping is not to scope name.
        assert_init_raises(ValueError, ckpt_dir,
                           {"useful_scope": "some_scope/"})
示例#14
0
def init_from_checkpoint(checkpoint_dir, assignment_map):
    """Initialize variables from a checkpoint.

    Thin wrapper that forwards both arguments unchanged to
    `checkpoint_utils.init_from_checkpoint` and returns None.
    See `tf.contrib.framework.init_from_checkpoint` for the full contract.

    Args:
        checkpoint_dir: Checkpoint location, passed through as the first
            argument.
        assignment_map: Mapping from checkpoint names/scopes to graph
            variables or scopes, passed through as the second argument.
    """
    checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
示例#15
0
def _map_chkpt_var_name(chkpt_var_name, chkpt_to_graph_scope_map):
    """Translate a checkpoint variable name into a graph variable name.

    Returns the name unchanged when the scope map is None or empty. Otherwise
    the first map key that prefixes the name has that prefix swapped for its
    graph scope. Returns None when no key matches.
    """
    if not chkpt_to_graph_scope_map:
        return chkpt_var_name
    name = str(chkpt_var_name)
    for c_scope, g_scope in chkpt_to_graph_scope_map.items():
        if name.startswith(c_scope):
            # Rewrite only the matched leading scope. str.replace would also
            # rewrite later repeats of c_scope inside the name, and for an
            # empty c_scope it would insert g_scope between every character.
            return g_scope + name[len(c_scope):]
    return None


def _graph_var_shape(graph_var):
    """Return the full shape of a graph variable as a list.

    A list of variables is treated as axis-0 partitions: their first
    dimensions are summed and the remaining dimensions come from the first
    partition.
    """
    if not isinstance(graph_var, list):
        return graph_var.get_shape().as_list()
    shape = graph_var[0].get_shape().as_list()
    for part in graph_var[1:]:
        shape[0] += part.get_shape().as_list()[0]
    return shape


def restore_from_checkpoint(chkpt_dir_or_file,
                            restore_chkpt_scopes=None,
                            chkpt_to_graph_scope_map=None):
    """Initialize graph variables from a checkpoint, with scope remapping.

    Each checkpoint variable (except names containing '/Adam') is mapped to
    a graph variable name and looked up with ``check_if_variable``. It is
    queued for restoration when the graph shape equals the checkpoint shape.
    The queued entries are optionally filtered by checkpoint-name prefix,
    then passed to ``checkpoint_utils.init_from_checkpoint``. Progress is
    reported with ``print``.

    Args:
        chkpt_dir_or_file: Checkpoint directory or file, passed to
            ``tf.train.list_variables`` and ``init_from_checkpoint``.
        restore_chkpt_scopes: Optional list/tuple of checkpoint-name prefixes.
            When non-empty, only checkpoint variables whose names start with
            one of them are restored.
        chkpt_to_graph_scope_map: Optional dict mapping a checkpoint-name
            prefix to the graph-name prefix that replaces it. When non-empty,
            checkpoint variables matching no key are skipped.
    """
    if chkpt_to_graph_scope_map:
        print("map graph scopes to checkpoint scopes: {}".format(
            chkpt_to_graph_scope_map))

    print("list variables in checkpoint:")
    init_assignment_map = dict()
    for chkpt_var_name, chkpt_shape in tf.train.list_variables(
            chkpt_dir_or_file):
        # Skip names containing '/Adam' (presumably optimizer slot variables).
        if '/Adam' in str(chkpt_var_name):
            continue
        graph_var_name = _map_chkpt_var_name(chkpt_var_name,
                                             chkpt_to_graph_scope_map)
        if graph_var_name is None:
            continue
        graph_var = check_if_variable(graph_var_name)
        if graph_var is None:
            print(
                "var not found in graph: name={} checkpoint_name={} checkpoint_shape={}"
                .format(graph_var_name, chkpt_var_name, chkpt_shape))
            continue
        graph_shape = _graph_var_shape(graph_var)
        if graph_shape != chkpt_shape:
            print(
                "bad shape: name={} checkpoint_name={} shape={} checkpoint_shape={} var={}"
                .format(graph_var_name, chkpt_var_name, graph_shape,
                        chkpt_shape, graph_var))
        else:
            print(
                "ready to load: name={} checkpoint_name={} shape={} checkpoint_shape={} var={}"
                .format(graph_var_name, chkpt_var_name, graph_shape,
                        chkpt_shape, graph_var))
            init_assignment_map[chkpt_var_name] = graph_var_name

    if restore_chkpt_scopes:
        print("load var in checkpoint scope: {}".format(
            ", ".join(restore_chkpt_scopes)))
        assignment_map = {
            k: v
            for k, v in init_assignment_map.items()
            if any(k.startswith(var_scope) for var_scope in restore_chkpt_scopes)
        }
    else:
        print('load all possible var from checkpoint')
        assignment_map = init_assignment_map

    print("init_from_checkpoint: " + pprint.pformat(assignment_map))
    checkpoint_utils.init_from_checkpoint(chkpt_dir_or_file, assignment_map)