Beispiel #1
0
 def testNestedLists(self):
     """Trackables stored inside nested Python lists are tracked recursively.

     Checks that `util.list_objects` reaches trackables held in inner lists
     (including lists assigned as attributes), does not list plain values
     such as ints, that appending to a NoDependency-wrapped inner list does
     not prevent saving, and that replacing an element of a tracked inner
     list makes `checkpoint.save` raise ValueError.
     """
     a = tracking.AutoTrackable()
     a.l = []
     b = tracking.AutoTrackable()
     a.l.append([b])
     c = tracking.AutoTrackable()
     a.l[0].append(c)
     a_deps = util.list_objects(a)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     # A plain int can live in the list but must not be listed as an object.
     a.l[0].append(1)
     d = tracking.AutoTrackable()
     a.l[0].append(d)
     a_deps = util.list_objects(a)
     self.assertIn(d, a_deps)
     self.assertIn(b, a_deps)
     self.assertIn(c, a_deps)
     self.assertNotIn(1, a_deps)
     e = tracking.AutoTrackable()
     f = tracking.AutoTrackable()
     a.l1 = [[], [e]]
     a.l1[0].append(f)
     a_deps = util.list_objects(a)
     self.assertIn(e, a_deps)
     self.assertIn(f, a_deps)
     checkpoint = util.Checkpoint(a=a)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     # Mutating a NoDependency-wrapped inner list still allows saving.
     a.l[0].append(data_structures.NoDependency([]))
     a.l[0][-1].append(5)
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
     # Dirtying the inner list means the root object is unsaveable.
     a.l[0][1] = 2
     # assertRaisesRegexp is a deprecated alias that was removed from
     # unittest in Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(ValueError, "A list element was replaced"):
         checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
 def testNestedLists(self):
   """Trackables stored inside nested Python lists are tracked recursively.

   Checks that `util.list_objects` reaches trackables held in inner lists
   (including lists assigned as attributes), does not list plain values
   such as ints, that appending to a NoDependency-wrapped inner list does
   not prevent saving, and that replacing an element of a tracked inner
   list makes `checkpoint.save` raise ValueError.
   """
   a = tracking.AutoTrackable()
   a.l = []
   b = tracking.AutoTrackable()
   a.l.append([b])
   c = tracking.AutoTrackable()
   a.l[0].append(c)
   a_deps = util.list_objects(a)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   # A plain int can live in the list but must not be listed as an object.
   a.l[0].append(1)
   d = tracking.AutoTrackable()
   a.l[0].append(d)
   a_deps = util.list_objects(a)
   self.assertIn(d, a_deps)
   self.assertIn(b, a_deps)
   self.assertIn(c, a_deps)
   self.assertNotIn(1, a_deps)
   e = tracking.AutoTrackable()
   f = tracking.AutoTrackable()
   a.l1 = [[], [e]]
   a.l1[0].append(f)
   a_deps = util.list_objects(a)
   self.assertIn(e, a_deps)
   self.assertIn(f, a_deps)
   checkpoint = util.Checkpoint(a=a)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Mutating a NoDependency-wrapped inner list still allows saving.
   a.l[0].append(data_structures.NoDependency([]))
   a.l[0][-1].append(5)
   checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Dirtying the inner list means the root object is unsaveable.
   a.l[0][1] = 2
   # assertRaisesRegexp is a deprecated alias that was removed from unittest
   # in Python 3.12; assertRaisesRegex is the supported name.
   with self.assertRaisesRegex(ValueError, "A list element was replaced"):
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
Beispiel #3
0
 def testAddVariableOverwrite(self):
   """Re-adding a variable named "v" depends on the `overwrite` flag.

   With overwrite=True (in a new graph) the new variable replaces the old
   one in `util.list_objects`; with overwrite=False a ValueError is raised.
   """
   get_var = variable_scope.get_variable
   root = base.Trackable()
   first = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=get_var)
   self.assertEqual([root, first], util.list_objects(root))

   with ops.Graph().as_default():
     replacement = root._add_variable_with_custom_getter(
         name="v", getter=get_var, shape=[], overwrite=True)
     self.assertEqual([root, replacement], util.list_objects(root))

   with ops.Graph().as_default():
     expected_message = "already declared as a dependency"
     with self.assertRaisesRegex(ValueError, expected_message):
       root._add_variable_with_custom_getter(
           name="v", getter=get_var, shape=[], overwrite=False)
Beispiel #4
0
    def test_model(self):
        """Smoke-test DeepLabV3+ (resnet_50 backbone, bone_train=False).

        Builds a small model, fits it for one epoch, checks that `get_config`
        runs, and checks that every model variable is reachable through
        `trackable_util.list_objects`.
        """
        num_classes = 5
        model = build_deeplab_v3_plus(classes=num_classes,
                                      bone_arch='resnet_50',
                                      bone_init='imagenet',
                                      bone_train=False,
                                      aspp_filters=8,
                                      aspp_stride=16,
                                      low_filters=16,
                                      decoder_filters=4)
        model.compile(optimizer='sgd',
                      loss='sparse_categorical_crossentropy',
                      run_eagerly=test_utils.should_run_eagerly())
        # np.random.random() draws from [0, 1), so casting it to uint8 yields
        # an all-zero image. Sample integer pixel values instead.
        images = np.random.randint(0, 256, size=(2, 224, 224, 3),
                                   dtype=np.uint8)
        labels = np.random.randint(0, num_classes, (2, 224, 224))
        model.fit(images, labels, epochs=1, batch_size=10)

        # test config
        model.get_config()

        # check whether the model variables are present
        # in the trackable list of objects
        checkpointed_objects = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
Beispiel #5
0
 def testAddVariableOverwrite(self):
   """Re-adding a variable named "v" depends on the `overwrite` flag.

   With overwrite=True (in a new graph) the new variable replaces the old
   one in `util.list_objects`; with overwrite=False a ValueError is raised.
   """
   root = base.Trackable()
   a = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=variable_scope.get_variable)
   self.assertEqual([root, a], util.list_objects(root))
   with ops.Graph().as_default():
     b = root._add_variable_with_custom_getter(
         name="v", shape=[], overwrite=True,
         getter=variable_scope.get_variable)
     self.assertEqual([root, b], util.list_objects(root))
   with ops.Graph().as_default():
     # assertRaisesRegexp is a deprecated alias that was removed from unittest
     # in Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(
         ValueError, "already declared as a dependency"):
       root._add_variable_with_custom_getter(
           name="v", shape=[], overwrite=False,
           getter=variable_scope.get_variable)
Beispiel #6
0
    def test_model(self):
        """Smoke-test FBA matting: compile, fit one epoch, check tracking.

        Only the first of the four outputs gets a loss ('mse'); the other
        three are given None. Also checks that `get_config` runs and that
        every model variable is reachable via `trackable_util.list_objects`.
        """
        model = build_fba_matting()
        model.compile(
            optimizer='sgd',
            loss=['mse'] + [None] * 3,
            run_eagerly=test_utils.should_run_eagerly())
        # NOTE(review): random() is in [0, 1), so these uint8 casts produce
        # all-zero arrays. Confirm whether that is intended.
        inputs = [
            np.random.random((2, 240, 240, channels)).astype(np.uint8)
            for channels in (3, 2, 6)
        ]
        targets = [
            np.random.random((2, 240, 240, channels)).astype(np.float32)
            for channels in (7, 1, 3, 3)
        ]
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # Serializing the config must not raise.
        model.get_config()

        # Identity-based set: every model variable must be a tracked object.
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
    def test_timedistributed_dense(self):
        """TimeDistributed(Dense) trains and all its variables are tracked."""
        model = keras.models.Sequential()
        wrapped = keras.layers.TimeDistributed(keras.layers.Dense(2),
                                               input_shape=(3, 4))
        model.add(wrapped)
        model.compile(optimizer="rmsprop", loss="mse")
        inputs = np.random.random((10, 3, 4))
        targets = np.random.random((10, 3, 2))
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # Serializing the config must not raise.
        model.get_config()

        # Membership is checked on id() values rather than the objects.
        tracked_ids = set(map(id, trackable_util.list_objects(model)))
        for variable in model.variables:
            self.assertIn(id(variable), tracked_ids)
    def test_trackable_save_restore(self):
        """Round-trips a make_template-wrapped function through a checkpoint.

        Variables created by the template (scope "s1") and by the
        `_ManualScope` it calls are saved with `tf.train.Checkpoint` and
        restored into a second template (scope "s2") built from the same
        function.
        """
        with self.test_session():

            def _templated():
                # Both variables are 1-element resource variables, zero-init.
                def _zero_var(var_name):
                    return tf.compat.v1.get_variable(
                        var_name,
                        shape=[1],
                        initializer=tf.compat.v1.zeros_initializer(),
                        use_resource=True)

                first = _zero_var("v")
                second = _zero_var("v2")
                manual = _ManualScope()
                return first, first + 1., second, manual, manual()

            save_template = tf.compat.v1.make_template("s1", _templated)
            (v1_save, _, v2_save,
             manual_scope, manual_scope_v) = save_template()
            expected_ids = {
                id(obj) for obj in (v1_save, v2_save, manual_scope,
                                    manual_scope_v, save_template)
            }
            listed_ids = set(
                map(id, trackable_utils.list_objects(save_template)))
            self.assertEqual(expected_ids, listed_ids)
            self.assertDictEqual({"in_manual_scope": manual_scope_v},
                                 manual_scope._trackable_children())

            optimizer = adam.Adam(0.0)
            save_root = tf.train.Checkpoint(my_template=save_template,
                                            optimizer=optimizer)
            optimizer.minimize(v1_save.read_value, var_list=[v1_save])
            self.evaluate(
                [var.initializer for var in save_template.variables])
            optimizer_variables = [*optimizer.variables(),
                                   *optimizer._hyper.values()]
            self.evaluate([var.initializer for var in optimizer_variables])
            self.evaluate(v1_save.assign([12.]))
            self.evaluate(v2_save.assign([14.]))
            checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
            save_path = save_root.save(checkpoint_prefix)

            load_template = tf.compat.v1.make_template("s2", _templated)
            load_optimizer = adam.Adam(0.0)
            load_root = tf.train.Checkpoint(my_template=load_template,
                                            optimizer=load_optimizer)
            status = load_root.restore(save_path)
            var, var_plus_one, var2, _, _ = load_template()
            load_optimizer.minimize(var.read_value, var_list=[var])

            self.assertEqual({"v", "v2", "ManualScope"},
                             load_template._trackable_children().keys())
            status.assert_consumed().run_restore_ops()
            for expected, tensor in (([12.], var), ([13.], var_plus_one),
                                     ([14.], var2)):
                self.assertAllEqual(expected, self.evaluate(tensor))
Beispiel #9
0
 def testDictWrapperNoDependency(self):
     """A dict wrapped in NoDependency adds nothing to the object graph.

     Only the owner itself is listed, and a Model holding the owner can
     still save and load its weights.
     """
     owner = tracking.AutoTrackable()
     owner.d = data_structures.NoDependency({})
     owner.d[1] = [3]
     self.assertEqual([owner], util.list_objects(owner))
     holder = training.Model()
     holder.sub = owner
     ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
     holder.save_weights(ckpt_path)
     holder.load_weights(ckpt_path)
 def testDictWrapperNoDependency(self):
   """A dict wrapped in NoDependency adds nothing to the object graph.

   Only the owner itself is listed, and a Model holding the owner can still
   save and load its weights.
   """
   owner = tracking.AutoTrackable()
   owner.d = data_structures.NoDependency({})
   owner.d[1] = [3]
   self.assertEqual([owner], util.list_objects(owner))
   holder = training.Model()
   holder.sub = owner
   ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
   holder.save_weights(ckpt_path)
   holder.load_weights(ckpt_path)
 def testNonStringKeyNotTrackableValue(self):
   """A NoDependency value under a non-string dict key is allowed.

   Only the owner, its dict, and the list stored under "a" are listed, and
   a Model holding the owner can still save and load its weights.
   """
   owner = tracking.AutoTrackable()
   owner.d = {}
   owner.d["a"] = [3]
   owner.d[1] = data_structures.NoDependency([3])
   expected = [owner, owner.d, owner.d["a"]]
   self.assertEqual(expected, util.list_objects(owner))
   holder = training.Model()
   holder.sub = owner
   ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
   holder.save_weights(ckpt_path)
   holder.load_weights(ckpt_path)
Beispiel #12
0
 def testNonStringKeyNotTrackableValue(self):
     """A NoDependency value under a non-string dict key is allowed.

     Only the owner, its dict, and the list stored under "a" are listed,
     and a Model holding the owner can still save and load its weights.
     """
     owner = tracking.AutoTrackable()
     owner.d = {}
     owner.d["a"] = [3]
     owner.d[1] = data_structures.NoDependency([3])
     expected = [owner, owner.d, owner.d["a"]]
     self.assertEqual(expected, util.list_objects(owner))
     holder = training.Model()
     holder.sub = owner
     ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
     holder.save_weights(ckpt_path)
     holder.load_weights(ckpt_path)
Beispiel #13
0
  def test_trackable_save_restore(self):
    """Saves a template's variables with Checkpoint and restores them.

    A second template ("s2") built from the same function as the first
    ("s1") receives the saved values after `restore`. Its checkpoint
    dependencies are checked to be "v", "v2", "ManualScope" in order.
    """

    def _templated():
      # Both variables are 1-element resource variables, zero-initialized.
      def _zero_var(var_name):
        return variable_scope.get_variable(
            var_name, shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)

      first = _zero_var("v")
      second = _zero_var("v2")
      manual = _ManualScope()
      return first, first + 1., second, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    expected_ids = [id(obj) for obj in (v1_save, v2_save, manual_scope,
                                        manual_scope_v, save_template)]
    six.assertCountEqual(
        self, expected_ids,
        map(id, trackable_utils.list_objects(save_template)))
    (manual_dep,) = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)

    optimizer = adam.Adam(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value, var_list=[v1_save])
    self.evaluate([var.initializer for var in save_template.variables])
    optimizer_variables = [*optimizer.variables(),
                           *optimizer._hyper.values()]
    self.evaluate([var.initializer for var in optimizer_variables])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.Adam(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value, var_list=[var])
    self.assertLen(load_template._checkpoint_dependencies, 3)
    for index, expected_name in enumerate(("v", "v2", "ManualScope")):
      self.assertEqual(expected_name,
                       load_template._checkpoint_dependencies[index].name)
    status.assert_consumed().run_restore_ops()
    for expected, tensor in (([12.], var), ([13.], var_plus_one),
                             ([14.], var2)):
      self.assertAllEqual(expected, self.evaluate(tensor))
Beispiel #14
0
 def testShallowCopyTrackable(self):
     """copy.copy shares sub-objects and keeps them tracked.

     The copy is a distinct object whose dict still holds the original
     sub-object; the copy's containers and that sub-object are all listed
     by `util.list_objects`.
     """
     original = tracking.AutoTrackable()
     original_sub = tracking.AutoTrackable()
     original.a = [[1.]]
     original.b = {"a": original_sub}
     clone = copy.copy(original)
     self.assertIs(original_sub, clone.b["a"])
     self.assertIsNot(original, clone)
     self.assertEqual([[1.]], clone.a)
     tracked = util.list_objects(clone)
     for expected in (clone.a, clone.b, clone.b["a"]):
         self.assertIn(expected, tracked)
 def testShallowCopyTrackable(self):
   """copy.copy shares sub-objects and keeps them tracked.

   The copy is a distinct object whose dict still holds the original
   sub-object; the copy's containers and that sub-object are all listed by
   `util.list_objects`.
   """
   original = tracking.AutoTrackable()
   original_sub = tracking.AutoTrackable()
   original.a = [[1.]]
   original.b = {"a": original_sub}
   clone = copy.copy(original)
   self.assertIs(original_sub, clone.b["a"])
   self.assertIsNot(original, clone)
   self.assertEqual([[1.]], clone.a)
   tracked = util.list_objects(clone)
   for expected in (clone.a, clone.b, clone.b["a"]):
     self.assertIn(expected, tracked)
 def testListBasic(self):
     """A list attribute becomes a tracked child named "l".

     Both list members are reachable through `util.list_objects`, and the
     "l" entry of `_trackable_children()` contains them.
     """
     owner = autotrackable.AutoTrackable()
     first = autotrackable.AutoTrackable()
     owner.l = [first]
     second = autotrackable.AutoTrackable()
     owner.l.append(second)
     owner_deps = util.list_objects(owner)
     for member in (first, second):
         self.assertIn(member, owner_deps)
     self.assertIn("l", owner._trackable_children())
     list_child = owner._trackable_children()["l"]
     for member in (first, second):
         self.assertIn(member, list_child)
 def testListBasic(self):
   """A list attribute becomes the single checkpoint dependency "l".

   Both list members are reachable through `util.list_objects` and through
   the dependency's `ref`.
   """
   owner = tracking.AutoTrackable()
   first = tracking.AutoTrackable()
   owner.l = [first]
   second = tracking.AutoTrackable()
   owner.l.append(second)
   owner_deps = util.list_objects(owner)
   for member in (first, second):
     self.assertIn(member, owner_deps)
   (list_dep,) = owner._checkpoint_dependencies
   self.assertEqual("l", list_dep.name)
   for member in (first, second):
     self.assertIn(member, list_dep.ref)
Beispiel #18
0
 def testListBasic(self):
     """A list attribute becomes the single checkpoint dependency "l".

     Both list members are reachable through `util.list_objects` and
     through the dependency's `ref`.
     """
     owner = tracking.AutoTrackable()
     first = tracking.AutoTrackable()
     owner.l = [first]
     second = tracking.AutoTrackable()
     owner.l.append(second)
     owner_deps = util.list_objects(owner)
     for member in (first, second):
         self.assertIn(member, owner_deps)
     (list_dep,) = owner._checkpoint_dependencies
     self.assertEqual("l", list_dep.name)
     for member in (first, second):
         self.assertIn(member, list_dep.ref)
  def test_trackable_save_restore(self):
    """Saves a template's variables with Checkpoint and restores them.

    A second template ("s2") built from the same function as the first
    ("s1") receives the saved values after `restore`. Its checkpoint
    dependencies are checked to be "v", "v2", "ManualScope" in order.
    """

    def _templated():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    # Compare by id(). Otherwise assertCountEqual hashes the objects (and on
    # TypeError falls back to ==). tf.Variable is unhashable and overloads ==
    # when tensor equality is enabled. This matches the id-based comparison
    # used by the other versions of this test.
    six.assertCountEqual(
        self,
        [id(obj) for obj in
         [v1_save, v2_save, manual_scope, manual_scope_v, save_template]],
        [id(obj) for obj in trackable_utils.list_objects(save_template)])
    manual_dep, = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)
    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual("ManualScope",
                     load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
Beispiel #20
0
 def testDeepCopyTrackable(self):
     """copy.deepcopy duplicates nested trackables and tracks the copies.

     The AutoTrackable under b["a"] in the copy is a new instance. The
     copy's containers and that new instance are listed; the original
     sub-object is not.
     """
     original = tracking.AutoTrackable()
     original_sub = tracking.AutoTrackable()
     original.a = [[1.]]
     original.b = {"a": original_sub}
     clone = copy.deepcopy(original)
     self.assertIsNot(original, clone)
     cloned_sub = clone.b["a"]
     self.assertIsNot(original_sub, cloned_sub)
     self.assertEqual([[1.]], clone.a)
     self.assertIsInstance(cloned_sub, tracking.AutoTrackable)
     tracked = util.list_objects(clone)
     for expected in (clone.a, clone.b, cloned_sub):
         self.assertIn(expected, tracked)
     self.assertNotIn(original_sub, tracked)
Beispiel #21
0
    def test_model(self):
        """Smoke-test UperNet (bone_arch='swin_tiny_224', bone_train=False).

        Compiles and fits one epoch, checks that `get_config` runs, and checks
        that every model variable is reachable through
        `trackable_util.list_objects`.
        """
        model = build_uper_net(classes=2, bone_arch='swin_tiny_224', bone_init='imagenet', bone_train=False)
        model.compile(optimizer='sgd', loss='mse', run_eagerly=test_utils.should_run_eagerly())
        # np.random.random() draws from [0, 1), so casting it to uint8 yields
        # an all-zero image. Sample integer pixel values instead.
        images = np.random.randint(0, 256, size=(2, 240, 240, 3), dtype=np.uint8)
        targets = np.random.random((2, 240, 240, 2)).astype(np.float32)
        model.fit(images, targets, epochs=1, batch_size=10)

        # test config
        model.get_config()

        # check whether the model variables are present
        # in the trackable list of objects
        checkpointed_objects = object_identity.ObjectIdentitySet(trackable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
Beispiel #22
0
  def test_trackable_save_restore(self):
    """Saves a template's variables with Checkpoint and restores them.

    A second template ("s2") built from the same function as the first
    ("s1") receives the saved values after `restore`. Its trackable children
    are checked to be exactly {"v", "v2", "ManualScope"}.
    """

    def _templated():
      # Both variables are 1-element resource variables, zero-initialized.
      def _zero_var(var_name):
        return variable_scope.get_variable(
            var_name, shape=[1], initializer=init_ops.zeros_initializer(),
            use_resource=True)

      first = _zero_var("v")
      second = _zero_var("v2")
      manual = _ManualScope()
      return first, first + 1., second, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    expected_objects = [v1_save, v2_save, manual_scope, manual_scope_v,
                        save_template]
    six.assertCountEqual(
        self, list(map(id, expected_objects)),
        list(map(id, trackable_utils.list_objects(save_template))))
    self.assertDictEqual({"in_manual_scope": manual_scope_v},
                         manual_scope._trackable_children())

    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([var.initializer for var in save_template.variables])
    self.evaluate([var.initializer for var in optimizer.variables()])
    for value, target in (([12.], v1_save), ([14.], v2_save)):
      self.evaluate(target.assign(value))
    save_path = save_root.save(os.path.join(self.get_temp_dir(), "ckpt"))

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._trackable_children()))
    self.assertEqual({"v", "v2", "ManualScope"},
                     load_template._trackable_children().keys())
    status.assert_consumed().run_restore_ops()
    for expected, tensor in (([12.], var), ([13.], var_plus_one),
                             ([14.], var2)):
      self.assertAllEqual(expected, self.evaluate(tensor))
 def testDeepCopyTrackable(self):
   """copy.deepcopy duplicates nested trackables and tracks the copies.

   The dict attribute stays a dict instance across the copy, the
   AutoTrackable under b["a"] is a new instance, and only the copy's objects
   (not the original sub-object) are listed.
   """
   original = tracking.AutoTrackable()
   original_sub = tracking.AutoTrackable()
   original.a = [[1.]]
   original.b = {"a": original_sub}
   self.assertIsInstance(original.b, dict)
   clone = copy.deepcopy(original)
   self.assertIsInstance(clone.b, dict)
   self.assertIsNot(original, clone)
   cloned_sub = clone.b["a"]
   self.assertIsNot(original_sub, cloned_sub)
   self.assertEqual([[1.]], clone.a)
   self.assertIsInstance(cloned_sub, tracking.AutoTrackable)
   tracked = util.list_objects(clone)
   for expected in (clone.a, clone.b, cloned_sub):
     self.assertIn(expected, tracked)
   self.assertNotIn(original_sub, tracked)
  def test_trackable_dependencies(self):
    """Every variable of a fitted SimpleRNN model is reachable via tracking."""
    rnn = keras.layers.SimpleRNN
    x = np.random.random((2, 2, 2))
    y = np.random.random((2, 2))
    model = keras.models.Sequential()
    model.add(rnn(2))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, y, epochs=1, batch_size=1)

    # check whether the model variables are present in the
    # trackable list of objects. Membership is checked by id() because a
    # plain set() must hash the objects, and tf.Variable raises TypeError on
    # hash when tensor equality is enabled (the TF2 default).
    checkpointed_object_ids = {
        id(o) for o in trackable_util.list_objects(model)}
    for v in model.variables:
      self.assertIn(id(v), checkpointed_object_ids)
Beispiel #25
0
  def test_trackable_dependencies(self):
    """Every variable of a fitted SimpleRNN model is reachable via tracking."""
    x = np.random.random((2, 2, 2))
    y = np.random.random((2, 2))
    model = keras.models.Sequential()
    model.add(keras.layers.SimpleRNN(2))
    model.compile(optimizer='rmsprop', loss='mse',
                  run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, y, epochs=1, batch_size=1)

    # ObjectIdentitySet compares members by identity rather than by hash/==.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
      self.assertIn(variable, tracked)
Beispiel #26
0
    def test_weight_norm_dense(self):
        """WeightNorm(Dense) trains, serializes its config, and is tracked."""
        model = tf.keras.models.Sequential()
        dense = tf.keras.layers.Dense(2)
        model.add(WeightNorm(dense, input_shape=(3, 4)))
        model.compile(optimizer='rmsprop', loss='mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        inputs = np.random.random((10, 3, 4))
        targets = np.random.random((10, 3, 2))
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # Serializing the config must not raise.
        model.get_config()

        # Every model variable must be among the tracked objects (by identity).
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
Beispiel #27
0
    def test_custom_backward_layer(self, mode):
        """Bidirectional with an explicit backward layer trains and serializes.

        `mode` is passed as `merge_mode`. The target width is doubled for
        "concat" and equals `output_dim` otherwise.
        """
        rnn = keras.layers.SimpleRNN
        samples, dim, timesteps, output_dim = 2, 2, 2, 2

        x = np.random.random((samples, timesteps, dim))
        if mode == "concat":
            target_dim = 2 * output_dim
        else:
            target_dim = output_dim
        y = np.random.random((samples, target_dim))
        forward_layer = rnn(output_dim)
        backward_layer = rnn(output_dim, go_backwards=True)

        # test with Sequential model
        model = keras.models.Sequential()
        bidirectional = keras.layers.Bidirectional(
            forward_layer,
            merge_mode=mode,
            backward_layer=backward_layer,
            input_shape=(timesteps, dim),
        )
        model.add(bidirectional)
        model.compile(optimizer="rmsprop", loss="mse")
        model.fit(x, y, epochs=1, batch_size=1)

        # Every model variable must be a tracked object; membership is
        # checked on id() values rather than the objects themselves.
        tracked_ids = set(map(id, trackable_util.list_objects(model)))
        for variable in model.variables:
            self.assertIn(id(variable), tracked_ids)

        # compute_output_shape must agree with the built output's shape.
        expected_shape = model.layers[-1].output.shape
        computed_shape = model.layers[-1].compute_output_shape(
            (None, timesteps, dim))
        self.assertListEqual(computed_shape.as_list(),
                             expected_shape.as_list())

        # Config and a JSON round-trip must not raise.
        model.get_config()
        model = keras.models.model_from_json(model.to_json())
        model.summary()
Beispiel #28
0
 def testDictionariesBasic(self):
     """Models stored in a dict attribute (and a list inside it) are tracked.

     Both sub-models are listed, show up in `layers` of the owner and of the
     wrapped dict, and a checkpoint of the owner saves and restores with
     every saved value consumed.
     """
     outer = training.Model()
     first = training.Model()
     outer.attribute = {"b": first}
     second = training.Model()
     outer.attribute["c"] = []
     outer.attribute["c"].append(second)
     outer_deps = util.list_objects(outer)
     for sub_model in (first, second):
         self.assertIn(sub_model, outer_deps)
     self.assertIs(first, outer.attribute["b"])
     self.assertEqual({"b", "c"},
                      outer.attribute._trackable_children().keys())
     self.assertEqual([first, second], outer.layers)
     self.assertEqual([first, second], outer.attribute.layers)
     self.assertEqual([second], outer.attribute["c"].layers)
     checkpoint = tf.train.Checkpoint(a=outer)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = checkpoint.save(prefix)
     with self.cached_session():
         status = checkpoint.restore(save_path)
         status.assert_consumed().initialize_or_restore()
Beispiel #29
0
    def test_model(self):
        """TemporalConvNet trains, serializes its config, and is tracked."""
        model = tf.keras.models.Sequential()
        tcn = TemporalConvNet(filters=[5, 4, 3], kernel_size=3, dropout=0.2)
        model.add(tcn)
        model.compile(optimizer='rmsprop', loss='mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        inputs = np.random.random((10, 3, 4))
        targets = np.random.random((10, 3, 3))
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # Serializing the config must not raise.
        model.get_config()

        # Every model variable must be among the tracked objects (by identity).
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
Beispiel #30
0
    def test_model(self):
        """Smoke-test DexiNed with a single output class.

        Compiles with binary cross-entropy, fits one epoch, checks that
        `get_config` runs, and checks that every model variable is reachable
        through `trackable_util.list_objects`.
        """
        num_classes = 1
        model = build_dexi_ned(num_classes)
        model.compile(optimizer='sgd',
                      loss='binary_crossentropy',
                      run_eagerly=test_utils.should_run_eagerly())
        # np.random.random() draws from [0, 1), so casting it to uint8 yields
        # an all-zero image. Sample integer pixel values instead.
        images = np.random.randint(0, 256, size=(2, 224, 224, 3),
                                   dtype=np.uint8)
        # NOTE(review): with num_classes == 1 this draws only zeros.
        labels = np.random.randint(0, num_classes, (2, 224, 224))
        model.fit(images, labels, epochs=1, batch_size=1)

        # test config
        model.get_config()

        # check whether the model variables are present
        # in the trackable list of objects
        checkpointed_objects = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
 def testNonAppendNotTrackable(self):
   """Overwriting or deleting untracked dict values is allowed.

   Plain values and NoDependency-wrapped trackables can be replaced or
   deleted. Only the owner, the dict, and the list under "a" are listed, and
   a Model holding the owner can still save and load its weights.
   """
   owner = tracking.AutoTrackable()
   owner.d = {}
   owner.d["a"] = [3]
   # Overwrite, then delete, a non-trackable value.
   for value in (3, 2):
     owner.d[1] = value
   self.assertEqual(2, owner.d[1])
   del owner.d[1]
   # Replace one NoDependency-wrapped trackable with another.
   owner.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
   replacement = tracking.AutoTrackable()
   owner.d[2] = data_structures.NoDependency(replacement)
   self.assertIs(replacement, owner.d[2])
   expected = [owner, owner.d, owner.d["a"]]
   self.assertEqual(expected, util.list_objects(owner))
   holder = training.Model()
   holder.sub = owner
   ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
   holder.save_weights(ckpt_path)
   holder.load_weights(ckpt_path)
Beispiel #32
0
 def testNonAppendNotTrackable(self):
     """Overwriting or deleting untracked dict values is allowed.

     Plain values and NoDependency-wrapped trackables can be replaced or
     deleted. Only the owner, the dict, and the list under "a" are listed,
     and a Model holding the owner can still save and load its weights.
     """
     owner = tracking.AutoTrackable()
     owner.d = {}
     owner.d["a"] = [3]
     # Overwrite, then delete, a non-trackable value.
     for value in (3, 2):
         owner.d[1] = value
     self.assertEqual(2, owner.d[1])
     del owner.d[1]
     # Replace one NoDependency-wrapped trackable with another.
     owner.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
     replacement = tracking.AutoTrackable()
     owner.d[2] = data_structures.NoDependency(replacement)
     self.assertIs(replacement, owner.d[2])
     expected = [owner, owner.d, owner.d["a"]]
     self.assertEqual(expected, util.list_objects(owner))
     holder = training.Model()
     holder.sub = owner
     ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
     holder.save_weights(ckpt_path)
     holder.load_weights(ckpt_path)
Beispiel #33
0
    def testNoDependency(self):
        """NoDependency attributes and untracked __init__ add no dependencies.

        Only `hasdep` becomes a trackable child of `root`, while the wrapped
        `nodep` is returned unwrapped on attribute access. A Model whose
        __init__ is decorated with `no_automatic_dependency_tracking` lists
        only itself.
        """
        root = tf.Module()
        tracked_child = tf.Module()
        root.hasdep = tracked_child
        untracked_child = tf.Module()
        root.nodep = data_structures.NoDependency(untracked_child)
        self.assertLen(root._trackable_children(), 1)
        self.assertIs(root._trackable_children()["hasdep"], root.hasdep)
        for actual, expected in ((root.hasdep, tracked_child),
                                 (root.nodep, untracked_child)):
            self.assertIs(actual, expected)

        class NoDependencyModel(training.Model):

            @tf.__internal__.tracking.no_automatic_dependency_tracking
            def __init__(self):
                super(NoDependencyModel, self).__init__()
                self.a = []
                self.b = tf.Module()

        untracked_model = NoDependencyModel()
        self.assertEqual([untracked_model], util.list_objects(untracked_model))
Beispiel #34
0
    def testNoDependency(self):
        """NoDependency attributes and untracked __init__ add no dependencies.

        Only `hasdep` becomes a checkpoint dependency of `root`, while the
        wrapped `nodep` is returned unwrapped on attribute access. A Model
        whose __init__ is decorated with `no_automatic_dependency_tracking`
        lists only itself.
        """
        root = module.Module()
        tracked_child = module.Module()
        root.hasdep = tracked_child
        untracked_child = module.Module()
        root.nodep = data_structures.NoDependency(untracked_child)
        self.assertEqual(1, len(root._checkpoint_dependencies))
        self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
        for actual, expected in ((root.hasdep, tracked_child),
                                 (root.nodep, untracked_child)):
            self.assertIs(actual, expected)

        class NoDependencyModel(training.Model):

            @base.no_automatic_dependency_tracking
            def __init__(self):
                super(NoDependencyModel, self).__init__()
                self.a = []
                self.b = module.Module()

        untracked_model = NoDependencyModel()
        self.assertEqual([untracked_model], util.list_objects(untracked_model))
  def testNoDependency(self):
    """NoDependency keeps an attribute out of checkpoint dependencies."""
    parent = tracking.AutoTrackable()
    tracked_child = tracking.AutoTrackable()
    parent.hasdep = tracked_child
    untracked_child = tracking.AutoTrackable()
    parent.nodep = data_structures.NoDependency(untracked_child)

    # A single dependency is recorded, and it refers to `hasdep`.
    deps = parent._checkpoint_dependencies
    self.assertEqual(1, len(deps))
    self.assertIs(deps[0].ref, parent.hasdep)
    # Both attributes read back as the objects that were assigned.
    self.assertIs(parent.hasdep, tracked_child)
    self.assertIs(parent.nodep, untracked_child)

    class NoDependencyModel(training.Model):

      @base.no_automatic_dependency_tracking
      def __init__(self):
        super(NoDependencyModel, self).__init__()
        # Assigned with automatic tracking disabled.
        self.a = []
        self.b = tracking.AutoTrackable()

    model = NoDependencyModel()
    self.assertEqual([model], util.list_objects(model))
  def test_timedistributed_dense(self):
    """Train TimeDistributed(Dense) once and verify variable tracking."""
    model = keras.models.Sequential()
    wrapper = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4))
    model.add(wrapper)
    model.compile(optimizer='rmsprop', loss='mse')

    inputs = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(inputs, targets, epochs=1, batch_size=10)

    # get_config must not raise.
    model.get_config()

    # Every model variable must be reachable via trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
      self.assertIn(variable, tracked)
Beispiel #37
0
  def test_custom_backward_layer(self, mode):
    """Bidirectional with an explicit backward layer, for merge `mode`."""
    rnn_cls = keras.layers.SimpleRNN
    samples, dim, timesteps, output_dim = 2, 2, 2, 2

    inputs = np.random.random((samples, timesteps, dim))
    # 'concat' yields twice the per-direction width; other modes keep it.
    if mode == 'concat':
      target_dim = 2 * output_dim
    else:
      target_dim = output_dim
    targets = np.random.random((samples, target_dim))
    forward_layer = rnn_cls(output_dim)
    backward_layer = rnn_cls(output_dim, go_backwards=True)

    model = keras.models.Sequential()
    model.add(
        keras.layers.Bidirectional(
            forward_layer,
            merge_mode=mode,
            backward_layer=backward_layer,
            input_shape=(timesteps, dim)))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(inputs, targets, epochs=1, batch_size=1)

    # Every model variable must be reachable via trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
      self.assertIn(variable, tracked)

    # compute_output_shape must agree with the built layer's output shape.
    last_layer = model.layers[-1]
    expected_shape = last_layer.output.shape
    computed_shape = last_layer.compute_output_shape((None, timesteps, dim))
    self.assertListEqual(computed_shape.as_list(), expected_shape.as_list())

    # Config export and a JSON round trip must both succeed.
    model.get_config()
    model = keras.models.model_from_json(model.to_json())
    model.summary()
Beispiel #38
0
    def testListDeepCopy(self):
        """deepcopy of a wrapped list copies contents and dirtiness."""
        holder = tracking.AutoTrackable()
        raw_list = [[1.]]
        holder.a = raw_list
        clone = copy.deepcopy(holder.a)
        self.assertAllEqual([[1.]], clone)
        # Both the outer list and the nested list are fresh objects.
        self.assertIsNot(holder.a, clone)
        self.assertIsNot(holder.a[0], clone[0])

        # Listing works while the wrapper is clean.
        util.list_objects(holder.a)
        # Mutating the raw list behind the wrapper makes listing raise
        # ValueError, and a deep copy of the wrapper must behave the same.
        raw_list.append(1.)
        with self.assertRaises(ValueError):
            util.list_objects(holder.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(holder.a))
Beispiel #39
0
    def testDictDeepCopy(self):
        """deepcopy of a wrapped dict copies contents and dirtiness."""
        holder = tracking.AutoTrackable()
        raw_dict = {"a": [1.]}
        holder.a = raw_dict
        clone = copy.deepcopy(holder.a)
        self.assertAllEqual([1.], clone["a"])
        # The dict and the list value inside it are fresh objects.
        self.assertIsNot(holder.a, clone)
        self.assertIsNot(holder.a["a"], clone["a"])

        # Listing works while the wrapper is clean.
        util.list_objects(holder.a)
        # Adding a key to the raw dict behind the wrapper makes listing raise
        # ValueError, and a deep copy of the wrapper must behave the same.
        raw_dict["b"] = []
        with self.assertRaises(ValueError):
            util.list_objects(holder.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(holder.a))
  def testListDeepCopy(self):
    """A deep copy of a tracked list inherits its dirty state."""
    owner = tracking.AutoTrackable()
    backing = [[1.]]
    owner.a = backing
    duplicate = copy.deepcopy(owner.a)
    self.assertAllEqual([[1.]], duplicate)
    # Neither level of nesting is shared with the copy.
    self.assertIsNot(owner.a, duplicate)
    self.assertIsNot(owner.a[0], duplicate[0])

    # Clean wrapper: listing must succeed.
    util.list_objects(owner.a)
    # After an out-of-band mutation, listing raises ValueError for both the
    # wrapper and a fresh deep copy of it.
    backing.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(owner.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(owner.a))
  def testDictDeepCopy(self):
    """A deep copy of a tracked dict inherits its dirty state."""
    owner = tracking.AutoTrackable()
    backing = {"a": [1.]}
    owner.a = backing
    duplicate = copy.deepcopy(owner.a)
    self.assertAllEqual([1.], duplicate["a"])
    # Neither the mapping nor its list value is shared with the copy.
    self.assertIsNot(owner.a, duplicate)
    self.assertIsNot(owner.a["a"], duplicate["a"])

    # Clean wrapper: listing must succeed.
    util.list_objects(owner.a)
    # After an out-of-band insertion, listing raises ValueError for both the
    # wrapper and a fresh deep copy of it.
    backing["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(owner.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(owner.a))
Beispiel #42
0
    def test_model(self):
        """Build, train and inspect DeepLabV3+ with a PointRend head.

        Runs one training epoch on random data, checks that ``get_config``
        does not raise, and checks that every model variable is reachable
        through trackable dependency listing.
        """
        num_classes = 5
        model = build_deeplab_v3_plus_with_point_rend(classes=num_classes,
                                                      bone_arch='resnet_50',
                                                      bone_init='imagenet',
                                                      bone_train=False,
                                                      aspp_filters=8,
                                                      aspp_stride=16,
                                                      low_filters=16,
                                                      decoder_filters=4,
                                                      rend_strides=(2, 4),
                                                      rend_units=(256, ),
                                                      rend_points=(0.1697,
                                                                   0.0005),
                                                      rend_oversample=3,
                                                      rend_importance=0.75,
                                                      rend_weights=True,
                                                      rend_corners=True)
        model.compile(optimizer='sgd',
                      loss=['sparse_categorical_crossentropy', None],
                      run_eagerly=test_utils.should_run_eagerly())

        # The image input is uint8. Casting np.random.random() output (floats
        # in [0, 1)) to uint8 truncates every pixel to 0, which would feed an
        # all-black image. Draw integers over the full uint8 range instead.
        images = np.random.randint(0, 256, (2, 224, 224, 3), dtype=np.uint8)
        model.fit(
            {
                'image': images,
                'label': np.random.randint(0, num_classes, (2, 224, 224)),
                'weight': np.random.rand(2, 224, 224),
            },
            np.random.randint(0, num_classes, (2, 224, 224)),
            epochs=1,
            batch_size=10)

        # test config
        model.get_config()

        # check whether the model variables are present
        # in the trackable list of objects
        checkpointed_objects = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for v in model.variables:
            self.assertIn(v, checkpointed_objects)
 def testDictionariesBasic(self):
   """Models held in dict/list attributes are tracked and checkpointable."""
   outer = training.Model()
   child_b = training.Model()
   outer.attribute = {"b": child_b}
   child_c = training.Model()
   outer.attribute["c"] = []
   outer.attribute["c"].append(child_c)

   deps = util.list_objects(outer)
   self.assertIn(child_b, deps)
   self.assertIn(child_c, deps)
   # Indexing the wrapped dict returns the stored model itself.
   self.assertIs(child_b, outer.attribute["b"])
   dep_names = [dep.name for dep in outer.attribute._checkpoint_dependencies]
   six.assertCountEqual(self, ["b", "c"], dep_names)

   # Layer lists include the model reached through the nested list.
   self.assertEqual([child_b, child_c], outer.layers)
   self.assertEqual([child_b, child_c], outer.attribute.layers)
   self.assertEqual([child_c], outer.attribute["c"].layers)

   checkpoint = util.Checkpoint(a=outer)
   prefix = os.path.join(self.get_temp_dir(), "ckpt")
   save_path = checkpoint.save(prefix)
   with self.cached_session():
     status = checkpoint.restore(save_path)
     status.assert_consumed().initialize_or_restore()
Beispiel #44
0
    def test_model(self):
        """Stacked QRNN layers train once and expose all their variables."""
        first_layer = tf.keras.layers.Bidirectional(
            QRNN(units=12, window=2, zoneout=0.2, return_sequences=True))
        model = tf.keras.models.Sequential()
        model.add(first_layer)
        # return_sequences=True above passes the full sequence to this layer.
        model.add(QRNN(units=2, window=1))
        model.compile(optimizer='rmsprop',
                      loss='mse',
                      run_eagerly=testing_utils.should_run_eagerly())

        inputs = np.random.random((10, 3, 4))
        targets = np.random.random((10, 2))
        model.fit(inputs, targets, epochs=1, batch_size=10)

        # get_config must not raise.
        model.get_config()

        # Every model variable must be reachable via trackable dependencies.
        tracked = object_identity.ObjectIdentitySet(
            trackable_util.list_objects(model))
        for variable in model.variables:
            self.assertIn(variable, tracked)
Beispiel #45
0
 def testDictionariesBasic(self):
     """Dict- and list-held models are tracked, listed and checkpointed."""
     root = training.Model()
     first = training.Model()
     root.attribute = {"b": first}
     second = training.Model()
     root.attribute["c"] = []
     root.attribute["c"].append(second)

     reachable = util.list_objects(root)
     self.assertIn(first, reachable)
     self.assertIn(second, reachable)
     # The wrapped dict hands back the very object that was stored.
     self.assertIs(first, root.attribute["b"])
     names = [d.name for d in root.attribute._checkpoint_dependencies]
     six.assertCountEqual(self, ["b", "c"], names)

     # `second` is only reachable via the nested list, yet is still a layer.
     self.assertEqual([first, second], root.layers)
     self.assertEqual([first, second], root.attribute.layers)
     self.assertEqual([second], root.attribute["c"].layers)

     checkpoint = util.Checkpoint(a=root)
     save_path = checkpoint.save(
         os.path.join(self.get_temp_dir(), "ckpt"))
     with self.cached_session():
         restore_status = checkpoint.restore(save_path)
         restore_status.assert_consumed().initialize_or_restore()
Beispiel #46
0
    def test_bidirectional(self):
        """Bidirectional(SimpleRNN) trains and round-trips for each mode."""
        rnn = keras.layers.SimpleRNN
        samples, dim, timesteps, output_dim = 2, 2, 2, 2
        with self.cached_session():
            for merge in ('sum', 'concat', 'ave', 'mul'):
                inputs = np.random.random((samples, timesteps, dim))
                # 'concat' doubles the target width; other modes keep it.
                target_dim = output_dim * (2 if merge == 'concat' else 1)
                targets = np.random.random((samples, target_dim))

                model = keras.models.Sequential()
                model.add(
                    keras.layers.Bidirectional(rnn(output_dim),
                                               merge_mode=merge,
                                               input_shape=(timesteps, dim)))
                model.compile(optimizer='rmsprop', loss='mse')
                model.fit(inputs, targets, epochs=1, batch_size=1)

                # Every variable must be reachable via trackable deps.
                tracked = object_identity.ObjectIdentitySet(
                    trackable_util.list_objects(model))
                for variable in model.variables:
                    self.assertIn(variable, tracked)

                # compute_output_shape must match the built layer's output.
                last = model.layers[-1]
                expected = last.output.shape
                computed = last.compute_output_shape((None, timesteps, dim))
                self.assertListEqual(computed.as_list(), expected.as_list())

                # Config export and a JSON round trip must both succeed.
                model.get_config()
                model = keras.models.model_from_json(model.to_json())
                model.summary()