Exemplo n.º 1
0
  def testNames(self):
    """Round-trips UniqueNameTracker dependencies through a checkpoint.

    Two variables are tracked under the same requested name "x" and a third
    under "x_1". On the restore side, fresh variables are attached as "x",
    "x_1", "x_1_1" and "y"; `assert_consumed` plus the value checks confirm
    that the checkpoint was written under exactly those names.
    """
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # (requested dependency name, initial value), tracked in this order.
    to_track = [("x", 2.), ("x", 3.), ("x_1", 4.), ("y", 5.)]
    originals = [resource_variable_ops.ResourceVariable(value)
                 for _, value in to_track]
    tracker = containers.UniqueNameTracker()
    for (requested_name, _), variable in zip(to_track, originals):
      tracker.track(variable, requested_name)
    self.evaluate(tuple(variable.initializer for variable in originals))
    save_root = checkpointable_utils.Checkpoint(slots=tracker)
    saved_to = save_root.save(prefix)

    # Attribute each restored variable is attached under, paired with the
    # value it must hold after restoration.
    expected = [("x", 2.), ("x_1", 3.), ("x_1_1", 4.), ("y", 5.)]
    restore_slots = checkpointable.Checkpointable()
    restore_root = checkpointable_utils.Checkpoint(slots=restore_slots)
    status = restore_root.restore(saved_to)
    # The variables are attached only after restore() has been called; the
    # checks below rely on them still receiving the saved values.
    for attr_name, _ in expected:
      setattr(restore_slots, attr_name,
              resource_variable_ops.ResourceVariable(0.))
    status.assert_consumed().run_restore_ops()
    for attr_name, value in expected:
      self.assertEqual(value, self.evaluate(getattr(restore_slots, attr_name)))
Exemplo n.º 2
0
 def __init__(self):
   """Creates three variables and tracks them through a UniqueNameTracker.

   The tracker is stored as ``self.slotdeps``. The values returned by each
   ``track`` call are collected, in tracking order, into ``self.slots``.
   The requested name "x" is used twice on purpose.
   """
   tracker = containers.UniqueNameTracker()
   self.slotdeps = tracker
   # (initial value, requested dependency name), tracked in this order.
   requests = ((3., "x"), (4., "y"), (5., "x"))
   self.slots = [
       tracker.track(resource_variable_ops.ResourceVariable(value), name)
       for value, name in requests]
Exemplo n.º 3
0
 def testLayers(self):
   """Checks that a layer tracked by UniqueNameTracker exposes its weights.

   The Dense layer is fetched back through ``tracker.layers`` and called once
   on a zeros tensor. After that, the tracker's ``trainable_weights`` is
   expected to hold two entries (presumably the kernel and bias; confirm
   against the layer implementation).
   """
   tracker = containers.UniqueNameTracker()
   tracker.track(layers.Dense(3), "dense")
   tracked_layer = tracker.layers[0]
   tracked_layer(array_ops.zeros([1, 1]))
   weights = tracker.trainable_weights
   self.assertEqual(2, len(weights))