def testNames(self):
  """Names auto-suffixed by a UniqueNameTracker survive a save/restore cycle.

  "x" is requested twice and "x_1" once, so the tracker must generate
  suffixed names. The checkpoint is then restored into a plain
  Checkpointable, and the values must appear under attributes named
  "x", "x_1", "x_1_1" and "y".
  """
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  originals = [resource_variable_ops.ResourceVariable(value)
               for value in (2., 3., 4., 5.)]
  x1, x2, x3, y = originals
  slots = containers.UniqueNameTracker()
  requests = ((x1, "x"), (x2, "x"), (x3, "x_1"), (y, "y"))
  for variable, requested_name in requests:
    slots.track(variable, requested_name)
  self.evaluate(tuple(variable.initializer for variable in originals))
  # Keep the Checkpoint objects bound to locals for the whole test.
  save_root = checkpointable_utils.Checkpoint(slots=slots)
  saved_path = save_root.save(prefix)

  restore_slots = checkpointable.Checkpointable()
  restore_root = checkpointable_utils.Checkpoint(slots=restore_slots)
  # restore() runs before the attributes exist; the values are matched to
  # them afterwards, when each attribute is assigned below.
  status = restore_root.restore(saved_path)
  expected = (("x", 2.), ("x_1", 3.), ("x_1_1", 4.), ("y", 5.))
  for attribute, _ in expected:
    setattr(restore_slots, attribute,
            resource_variable_ops.ResourceVariable(0.))
  status.assert_consumed().run_restore_ops()
  for attribute, value in expected:
    self.assertEqual(value, self.evaluate(getattr(restore_slots, attribute)))
def __init__(self):
  """Track three variables in a UniqueNameTracker, two of them requesting "x".

  Sets `self.slotdeps` to the tracker. Sets `self.slots` to the values
  returned by `track()`, in tracking order.
  """
  tracker = containers.UniqueNameTracker()
  self.slotdeps = tracker
  requested = ((3., "x"), (4., "y"), (5., "x"))
  self.slots = [
      tracker.track(resource_variable_ops.ResourceVariable(initial), name)
      for initial, name in requested]
def testLayers(self):
  """A layer tracked by a UniqueNameTracker contributes its trainable weights.

  The Dense layer is fetched back through `tracker.layers` and called once so
  it builds. The tracker is then expected to report two trainable weights
  (presumably the layer's kernel and bias).
  """
  tracker = containers.UniqueNameTracker()
  tracker.track(layers.Dense(3), "dense")
  first_layer = tracker.layers[0]
  inputs = array_ops.zeros([1, 1])
  first_layer(inputs)
  self.assertEqual(2, len(tracker.trainable_weights))