コード例 #1
0
 def test_all_nodes(self):
   """descendants() of a saved root-with-one-child graph yields three ids.

   Builds an AutoTrackable `root` with a single child attached as `leaf`,
   writes a checkpoint rooted at it, then checks that the descendants list
   has three entries and that its first two entries are node ids 0 and 1.
   """
   parent = autotrackable.AutoTrackable()
   parent.leaf = autotrackable.AutoTrackable()
   checkpoint = trackable_utils.Checkpoint(root=parent)
   save_prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
   saved_path = checkpoint.save(save_prefix)
   view = checkpoint_view.CheckpointView(saved_path)
   descendant_ids = view.descendants()
   # NOTE(review): only two objects are built here, so the third entry is
   # presumably an object the Checkpoint itself adds to the graph — confirm.
   self.assertEqual(3, len(descendant_ids))
   for position in (0, 1):
     self.assertEqual(position, descendant_ids[position])
コード例 #2
0
 def test_all_nodes(self):
   """Checks descendants() ordering for a manually tracked child.

   The child is attached with `_track_trackable` under the name "leaf"
   instead of via attribute assignment. The assertions expect node id 1 at
   position 0 and node id 0 at position 1 of the descendants list.
   """
   parent = base.Trackable()
   child = base.Trackable()
   parent._track_trackable(child, name="leaf")
   checkpoint = trackable_utils.Checkpoint(root=parent)
   save_prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
   saved_path = checkpoint.save(save_prefix)
   descendant_ids = checkpoint_view.CheckpointView(saved_path).descendants()
   # NOTE(review): id 1 is expected before id 0; presumably this reflects the
   # traversal order of this version of descendants() — confirm.
   for position, expected_id in enumerate((1, 0)):
     self.assertEqual(expected_id, descendant_ids[position])
コード例 #3
0
 def test_children(self):
   """children(0) of the saved graph starts with the "leaf" edge to node 1.

   Saves an AutoTrackable root with one child attached as `leaf`, then
   inspects the first (name, node_id) pair from children(0).items().
   """
   parent = autotrackable.AutoTrackable()
   parent.leaf = autotrackable.AutoTrackable()
   checkpoint = trackable_utils.Checkpoint(root=parent)
   save_prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
   saved_path = checkpoint.save(save_prefix)
   view = checkpoint_view.CheckpointView(saved_path)
   # Only the first pair is examined; later children (if any) are ignored.
   first_edge = next(iter(view.children(0).items()))
   child_name, child_id = first_edge
   self.assertEqual("leaf", child_name)
   self.assertEqual(1, child_id)
コード例 #4
0
  def test_match_overlapping_nodes(self):
    """match() keeps a single binding when the checkpoint aliases a node.

    In the saved graph, attributes `a` and `b` refer to the same object.
    In the target graph, they refer to two distinct objects. The result is
    expected to contain only node ids 0 and 1. The first captured WARNING
    line must contain the inconsistent-references message.
    """
    saved_root = autotrackable.AutoTrackable()
    shared_child = autotrackable.AutoTrackable()
    # Same object under two names; `a` is assigned before `b`.
    saved_root.a = shared_child
    saved_root.b = shared_child
    checkpoint = trackable_utils.Checkpoint(root=saved_root)
    save_prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
    saved_path = checkpoint.save(save_prefix)

    target_root = autotrackable.AutoTrackable()
    target_a = autotrackable.AutoTrackable()
    target_root.a = target_a
    target_root.b = autotrackable.AutoTrackable()
    with self.assertLogs(level="WARNING") as captured:
      view = checkpoint_view.CheckpointView(saved_path)
      matches = view.match(target_root)
    # Only the first element at the same position will be matched.
    expected = {0: target_root, 1: target_a}
    self.assertDictEqual(matches, expected)
    expected_message = (
        "Inconsistent references when matching the checkpoint into this object"
        " graph.")
    self.assertIn(expected_message, captured.output[0])
コード例 #5
0
  def test_match(self):
    """match() maps saved node ids onto the corresponding target objects.

    The saved graph has `leaf1` and `leaf2` under the root, `leaf3` and
    `leaf4` under `leaf1`, and `leaf5` under `leaf2`. The target graph
    mirrors it but has no `leaf4`. The expected mapping covers ids 0, 1, 2,
    4 and 6 only.
    """
    saved_root = autotrackable.AutoTrackable()
    saved_leaf1 = autotrackable.AutoTrackable()
    saved_root.leaf1 = saved_leaf1
    saved_leaf2 = autotrackable.AutoTrackable()
    saved_root.leaf2 = saved_leaf2
    saved_leaf1.leaf3 = autotrackable.AutoTrackable()
    saved_leaf1.leaf4 = autotrackable.AutoTrackable()
    saved_leaf2.leaf5 = autotrackable.AutoTrackable()
    checkpoint = trackable_utils.Checkpoint(root=saved_root)
    save_prefix = os.path.join(self.get_temp_dir(), "root_ckpt")
    saved_path = checkpoint.save(save_prefix)

    target_root = autotrackable.AutoTrackable()
    target_leaf1 = autotrackable.AutoTrackable()
    target_root.leaf1 = target_leaf1
    target_leaf2 = autotrackable.AutoTrackable()
    target_root.leaf2 = target_leaf2
    target_leaf3 = autotrackable.AutoTrackable()
    target_leaf1.leaf3 = target_leaf3
    target_leaf5 = autotrackable.AutoTrackable()
    target_leaf2.leaf5 = target_leaf5
    view = checkpoint_view.CheckpointView(saved_path)
    matches = view.match(target_root)
    # NOTE(review): ids 3 and 5 are absent from the expected map. Presumably
    # 5 is `leaf4` (missing from the target) and 3 is an object the
    # Checkpoint adds itself — confirm.
    expected = {
        0: target_root,
        1: target_leaf1,
        2: target_leaf2,
        4: target_leaf3,
        6: target_leaf5,
    }
    self.assertDictEqual(matches, expected)