def testSaveRestoreNumpyState(self):
        """Saves NumpyState arrays to two checkpoints and restores each one.

        The first checkpoint is written before an in-place edit of ``a`` and
        the second one after it. Restoring them into a separate NumpyState
        must reproduce the values that were current at save time, and a
        repeated restore must overwrite an in-place edit made on the loading
        side.
        """
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        original = python_state.NumpyState()
        writer = util.Checkpoint(numpy=original)
        original.a = numpy.ones([2, 2])
        # Re-assigning `b` replaces the earlier array; the restores below
        # expect only the zeros to come back.
        original.b = numpy.ones([2, 2])
        original.b = numpy.zeros([2, 2])
        self.assertAllEqual(numpy.ones([2, 2]), original.a)
        self.assertAllEqual(numpy.zeros([2, 2]), original.b)
        before_edit_path = writer.save(ckpt_prefix)
        # Mutate the tracked array in place (no re-assignment); the next save
        # is expected to capture the edit.
        original.a[1, 1] = 2.
        after_edit_path = writer.save(ckpt_prefix)

        loaded = python_state.NumpyState()
        reader = util.Checkpoint(numpy=loaded)
        reader.restore(before_edit_path).initialize_or_restore()
        self.assertAllEqual(numpy.ones([2, 2]), loaded.a)
        self.assertAllEqual(numpy.zeros([2, 2]), loaded.b)
        # A local in-place edit on the restored side...
        loaded.a[0, 0] = 42.
        self.assertAllEqual([[42., 1.], [1., 1.]], loaded.a)
        # ...is discarded by restoring the same checkpoint again.
        reader.restore(before_edit_path).run_restore_ops()
        self.assertAllEqual(numpy.ones([2, 2]), loaded.a)
        reader.restore(after_edit_path).run_restore_ops()
        self.assertAllEqual([[1., 1.], [1., 2.]], loaded.a)
        self.assertAllEqual(numpy.zeros([2, 2]), loaded.b)
    def testDocstringExample(self):
        """Runs a NumpyState save/restore usage example.

        Covers two cases: restoring into the same NumpyState after an
        in-place edit, and restoring into a new NumpyState attached to a
        second Checkpoint.
        """
        tracked = python_state.NumpyState()
        ckpt = util.Checkpoint(numpy_arrays=tracked)
        tracked.x = numpy.zeros([3, 4])
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        saved_to = ckpt.save(ckpt_prefix)
        # Edit in place after saving; the restore below should discard it.
        tracked.x[1, 1] = 4.
        # NOTE(review): no run_restore_ops() here -- the assertion relies on
        # restore() taking effect immediately (presumably eager mode); confirm
        # how this test is executed.
        ckpt.restore(saved_to)
        self.assertAllEqual(numpy.zeros([3, 4]), tracked.x)

        # A second Checkpoint holding an empty NumpyState reads `x` back from
        # the same checkpoint files.
        other_ckpt = util.Checkpoint(numpy_arrays=python_state.NumpyState())
        other_ckpt.restore(saved_to)
        expected = numpy.zeros([3, 4])
        self.assertAllEqual(expected, other_ckpt.numpy_arrays.x)
 def testNoGraphPollution(self):
     """Checks that repeated save/restore does not add ops to a finalized graph.

     An initial save and restore run while the graph is still mutable.
     ``graph.finalize()`` then marks the graph read-only, and the later save
     and restore calls are expected to complete against it.
     """
     isolated_graph = ops.Graph()
     with isolated_graph.as_default():
         with session.Session():
             ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
             state = python_state.NumpyState()
             ckpt = util.Checkpoint(numpy=state)
             state.a = numpy.ones([2, 2])
             # Warm-up round trip while the graph can still be modified.
             first_path = ckpt.save(ckpt_prefix)
             ckpt.restore(first_path)
             isolated_graph.finalize()
             # Everything below runs against the finalized graph.
             ckpt.save(ckpt_prefix)
             # Re-assign with a new array of the same [2, 2] shape.
             state.a = numpy.zeros([2, 2])
             ckpt.save(ckpt_prefix)
             ckpt.restore(first_path)
 def testNoMixedNumpyStateTF(self):
     """Assigning a TF variable to a NumpyState raises NotImplementedError."""
     mixed = python_state.NumpyState()
     mixed.a = numpy.ones([2, 2])
     with self.assertRaises(NotImplementedError):
         # The variable is created inside the context, as in the assignment
         # form; only the attribute assignment is expected to raise.
         tf_var = variables.Variable(1.)
         mixed.v = tf_var