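# NOTE: _create_checkpoints is not part of this snippet. The sketch below is
# a reconstruction from the assertions in the tests (variable names, shapes,
# and the "useful_scope" nesting); the exact save logic is an assumption. It
# writes var1..var4 to checkpoint_dir and returns their initial values.
import os
import tensorflow as tf


def _create_checkpoints(sess, checkpoint_dir):
  checkpoint_prefix = os.path.join(checkpoint_dir, "model")
  v1 = tf.get_variable("var1", [1, 10])
  v2 = tf.get_variable("var2", [10, 10])
  v3 = tf.get_variable("var3", [100, 100])
  with tf.variable_scope("useful_scope"):
    v4 = tf.get_variable("var4", [9, 9])
  sess.run(tf.initialize_all_variables())
  v1_value, v2_value, v3_value, v4_value = sess.run([v1, v2, v3, v4])
  tf.train.Saver().save(sess, checkpoint_prefix)
  return v1_value, v2_value, v3_value, v4_value
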
  def testInitFromCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with tf.variable_scope("some_scope"):
          my1 = tf.get_variable("my1", [1, 10])
          with tf.variable_scope("some_other_scope"):
            my2 = tf.get_variable("my2", [10, 10])
            with tf.variable_scope("other_useful_scope"):
              my4 = tf.get_variable("var4", [9, 9])
        my3 = tf.get_variable("my3", [100, 100])

        checkpoints.init_from_checkpoint(checkpoint_dir, {
            "some_scope/my1": "var1",
            "some_scope/some_other_scope/other_useful_scope/": "useful_scope/",
        })
        checkpoints.init_from_checkpoint(checkpoint_dir, {
            "some_scope/some_other_scope/my2": "var2",
            "my3": "var3",
        })

        session.run(tf.initialize_all_variables())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)

        # Check that tensors are not explicitly in the graph.
        self.assertLess(len(str(session.graph.as_graph_def())), 22000)
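        # (The bound is a heuristic: had init_from_checkpoint embedded the
        # restored values as graph constants, var3's 100x100 floats alone
        # would serialize to roughly 40KB, well past the limit.)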
  def testInitFromPartitionVar(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with tf.variable_scope("some_scope"):
          # TODO(ipolosukhin): Enable this when get_variable partitioning works.
          # Currently get_variable with partitioner doesn't return Variable,
          # but returns a concat op.
          # my1 = tf.get_variable(
          #     "my1", [100, 100],
          #     partitioner=tf.variable_axis_size_partitioner(
          #         axis=0, max_shard_bytes=100))
          my1 = tf.create_partitioned_variables(
              shape=[100, 100], slicing=[5, 1], name="my1",
              initializer=tf.truncated_normal_initializer(0.5))
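          # NOTE: create_partitioned_variables returns a list of slice
          # variables (here five [20, 100] row slices) rather than a single
          # tf.Variable, so session.run(my1) below yields one array per
          # slice.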

        checkpoints.init_from_checkpoint(checkpoint_dir, {
            "some_scope/my1": "var1",
        })

        session.run(tf.initialize_all_variables())
        my1_values = session.run(my1)
        self.assertAllEqual(my1_values, v1)
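
  # _create_partition_checkpoints is likewise not part of this snippet. A
  # hypothetical sketch, mirroring the slicing used above: it would save a
  # [100, 100] variable "var1" split into five row slices and return the
  # slice values that the assertion compares against.
  #
  #   def _create_partition_checkpoints(sess, checkpoint_dir):
  #     checkpoint_prefix = os.path.join(checkpoint_dir, "model")
  #     v1 = tf.create_partitioned_variables(
  #         shape=[100, 100], slicing=[5, 1], name="var1",
  #         initializer=tf.truncated_normal_initializer(0.5))
  #     sess.run(tf.initialize_all_variables())
  #     v1_value = sess.run(v1)
  #     tf.train.Saver().save(sess, checkpoint_prefix)
  #     return v1_value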
  def testInitFromRootCheckpoint(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with tf.variable_scope("some_scope"):
          my1 = tf.get_variable("var1", [1, 10])
          my2 = tf.get_variable("var2", [10, 10])
          my3 = tf.get_variable("var3", [100, 100])
          with tf.variable_scope("useful_scope"):
            my4 = tf.get_variable("var4", [9, 9])

        checkpoints.init_from_checkpoint(checkpoint_dir, {
            "some_scope/": "/",
        })

        session.run(tf.initialize_all_variables())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)
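        # Mapping "some_scope/" to "/" remaps the checkpoint root, so every
        # checkpoint variable (var1..var3, plus useful_scope/var4)
        # initializes the identically named variable under "some_scope".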
  def testInitFromCheckpointMissing(self):
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      _, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with tf.variable_scope("some_scope"):
          _ = tf.get_variable("my1", [10, 10])
          _ = tf.get_variable("my2", [1, 10],
                              dtype=tf.int64, initializer=tf.zeros_initializer)

        # No directory.
        with self.assertRaises(tf.errors.OpError):
          checkpoints.init_from_checkpoint("no_dir", {
              "some_scope/my1": "var1"})

        # No variable in checkpoint.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/my1": "no_var"})

        # No variable in the graph.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/no_var": "var3"})

        # Shape mismatch.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/my1": "var1"})

        # DType mismatch.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/my2": "var1"})

        # Variables 'my1' and 'my2' are missing from the given checkpoint scope.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/": "useful_scope/"})

        # Scope is mapped to a name without a trailing "/", i.e. not a scope name.
        with self.assertRaises(ValueError):
          checkpoints.init_from_checkpoint(checkpoint_dir, {
              "some_scope/": "useful_scope"})