def testNestedLists(self):
    """Check that trackables inside nested lists are listed and saveable.

    Also checks that overwriting an element of a tracked inner list makes
    the enclosing checkpoint refuse to save.
    """
    a = tracking.AutoTrackable()
    a.l = []
    b = tracking.AutoTrackable()
    a.l.append([b])
    c = tracking.AutoTrackable()
    a.l[0].append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)

    # A plain int stored next to trackables must not show up as a dependency.
    a.l[0].append(1)
    d = tracking.AutoTrackable()
    a.l[0].append(d)
    a_deps = util.list_objects(a)
    self.assertIn(d, a_deps)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertNotIn(1, a_deps)

    e = tracking.AutoTrackable()
    f = tracking.AutoTrackable()
    a.l1 = [[], [e]]
    a.l1[0].append(f)
    a_deps = util.list_objects(a)
    self.assertIn(e, a_deps)
    self.assertIn(f, a_deps)

    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint = util.Checkpoint(a=a)
    checkpoint.save(ckpt_prefix)
    a.l[0].append(data_structures.NoDependency([]))
    a.l[0][-1].append(5)
    checkpoint.save(ckpt_prefix)
    # Dirtying the inner list means the root object is unsaveable.
    a.l[0][1] = 2
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
        checkpoint.save(ckpt_prefix)
def testNestedLists(self):
    """Check that trackables inside nested lists are listed and saveable.

    Also checks that overwriting an element of a tracked inner list makes
    the enclosing checkpoint refuse to save.
    """
    a = tracking.AutoTrackable()
    a.l = []
    b = tracking.AutoTrackable()
    a.l.append([b])
    c = tracking.AutoTrackable()
    a.l[0].append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)

    # A plain int stored next to trackables must not show up as a dependency.
    a.l[0].append(1)
    d = tracking.AutoTrackable()
    a.l[0].append(d)
    a_deps = util.list_objects(a)
    self.assertIn(d, a_deps)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertNotIn(1, a_deps)

    e = tracking.AutoTrackable()
    f = tracking.AutoTrackable()
    a.l1 = [[], [e]]
    a.l1[0].append(f)
    a_deps = util.list_objects(a)
    self.assertIn(e, a_deps)
    self.assertIn(f, a_deps)

    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint = util.Checkpoint(a=a)
    checkpoint.save(ckpt_prefix)
    a.l[0].append(data_structures.NoDependency([]))
    a.l[0][-1].append(5)
    checkpoint.save(ckpt_prefix)
    # Dirtying the inner list means the root object is unsaveable.
    a.l[0][1] = 2
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
        checkpoint.save(ckpt_prefix)
def testAddVariableOverwrite(self):
    """Re-adding variable "v" works with overwrite=True and raises without it."""
    root = base.Trackable()
    first = root._add_variable_with_custom_getter(
        name="v", shape=[], getter=variable_scope.get_variable)
    self.assertEqual([root, first], util.list_objects(root))

    # Overwriting under a fresh graph swaps the tracked variable.
    with ops.Graph().as_default():
        replacement = root._add_variable_with_custom_getter(
            name="v",
            shape=[],
            overwrite=True,
            getter=variable_scope.get_variable)
        self.assertEqual([root, replacement], util.list_objects(root))

    # Without overwrite, the duplicate name is rejected.
    with ops.Graph().as_default():
        with self.assertRaisesRegex(
                ValueError, "already declared as a dependency"):
            root._add_variable_with_custom_getter(
                name="v",
                shape=[],
                overwrite=False,
                getter=variable_scope.get_variable)
def test_model(self):
    """Fit the DeepLabV3+ model for one epoch, then check config and tracking."""
    num_classes = 5
    model = build_deeplab_v3_plus(
        classes=num_classes,
        bone_arch='resnet_50',
        bone_init='imagenet',
        bone_train=False,
        aspp_filters=8,
        aspp_stride=16,
        low_filters=16,
        decoder_filters=4)
    model.compile(
        optimizer='sgd',
        loss='sparse_categorical_crossentropy',
        run_eagerly=test_utils.should_run_eagerly())

    # np.random.random() returns floats in [0, 1); casting those to uint8
    # truncates every value to 0, so the model would only ever see constant
    # black images. Draw actual 8-bit values instead.
    images = np.random.randint(0, 256, (2, 224, 224, 3), dtype=np.uint8)
    labels = np.random.randint(0, num_classes, (2, 224, 224))
    model.fit(images, labels, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # check whether the model variables are present
    # in the trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
        self.assertIn(v, checkpointed_objects)
def testAddVariableOverwrite(self):
    """Re-adding variable "v" works with overwrite=True and raises without it."""
    root = base.Trackable()
    a = root._add_variable_with_custom_getter(
        name="v", shape=[], getter=variable_scope.get_variable)
    self.assertEqual([root, a], util.list_objects(root))

    # Overwriting under a fresh graph swaps the tracked variable.
    with ops.Graph().as_default():
        b = root._add_variable_with_custom_getter(
            name="v", shape=[], overwrite=True,
            getter=variable_scope.get_variable)
        self.assertEqual([root, b], util.list_objects(root))

    with ops.Graph().as_default():
        # assertRaisesRegexp is a deprecated alias that was removed in
        # Python 3.12; assertRaisesRegex is the supported spelling.
        with self.assertRaisesRegex(
                ValueError, "already declared as a dependency"):
            root._add_variable_with_custom_getter(
                name="v", shape=[], overwrite=False,
                getter=variable_scope.get_variable)
def test_model(self):
    """Fit the FBA matting model for one epoch, then check config and tracking."""
    model = build_fba_matting()
    model.compile(
        optimizer='sgd',
        loss=['mse', None, None, None],
        run_eagerly=test_utils.should_run_eagerly())

    # np.random.random() returns floats in [0, 1); casting those to uint8
    # truncates every value to 0, so all three inputs would be constant
    # zeros. Draw actual 8-bit values instead.
    inputs = [
        np.random.randint(0, 256, (2, 240, 240, 3), dtype=np.uint8),
        np.random.randint(0, 256, (2, 240, 240, 2), dtype=np.uint8),
        np.random.randint(0, 256, (2, 240, 240, 6), dtype=np.uint8),
    ]
    targets = [
        np.random.random((2, 240, 240, 7)).astype(np.float32),
        np.random.random((2, 240, 240, 1)).astype(np.float32),
        np.random.random((2, 240, 240, 3)).astype(np.float32),
        np.random.random((2, 240, 240, 3)).astype(np.float32),
    ]
    model.fit(inputs, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # check whether the model variables are present
    # in the trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
        self.assertIn(v, checkpointed_objects)
def test_timedistributed_dense(self):
    """Fit a TimeDistributed(Dense) model and verify its variables are tracked."""
    wrapped = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4)
    )
    model = keras.models.Sequential()
    model.add(wrapped)
    model.compile(optimizer="rmsprop", loss="mse")

    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(features, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Every model variable must appear (by identity) among the objects
    # reachable through trackable dependencies.
    tracked_ids = set(map(id, trackable_util.list_objects(model)))
    for variable in model.variables:
        self.assertIn(id(variable), tracked_ids)
def test_trackable_save_restore(self):
    """Save a template plus optimizer, then restore into a second template."""
    with self.test_session():

        def _templated():
            zeros = tf.compat.v1.zeros_initializer
            v = tf.compat.v1.get_variable(
                "v", shape=[1], initializer=zeros(), use_resource=True)
            v2 = tf.compat.v1.get_variable(
                "v2", shape=[1], initializer=zeros(), use_resource=True)
            manual = _ManualScope()
            return v, v + 1., v2, manual, manual()

        save_template = tf.compat.v1.make_template("s1", _templated)
        v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()

        # The template tracks exactly these five objects (compared by id).
        expected_ids = {
            id(obj)
            for obj in (v1_save, v2_save, manual_scope, manual_scope_v,
                        save_template)
        }
        actual_ids = {
            id(obj) for obj in trackable_utils.list_objects(save_template)
        }
        self.assertEqual(expected_ids, actual_ids)
        self.assertDictEqual({"in_manual_scope": manual_scope_v},
                             manual_scope._trackable_children())

        optimizer = adam.Adam(0.0)
        save_root = tf.train.Checkpoint(
            my_template=save_template, optimizer=optimizer)
        optimizer.minimize(v1_save.read_value, var_list=[v1_save])
        self.evaluate([var.initializer for var in save_template.variables])
        optimizer_variables = optimizer.variables() + list(
            optimizer._hyper.values())
        self.evaluate([var.initializer for var in optimizer_variables])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))

        checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = tf.compat.v1.make_template("s2", _templated)
        load_optimizer = adam.Adam(0.0)
        load_root = tf.train.Checkpoint(
            my_template=load_template, optimizer=load_optimizer)
        status = load_root.restore(save_path)
        var, var_plus_one, var2, _, _ = load_template()
        load_optimizer.minimize(var.read_value, var_list=[var])

        children = load_template._trackable_children()
        self.assertEqual({"v", "v2", "ManualScope"}, children.keys())
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
def testDictWrapperNoDependency(self):
    """A NoDependency-wrapped dict adds nothing to the listed objects."""
    holder = tracking.AutoTrackable()
    holder.d = data_structures.NoDependency({})
    holder.d[1] = [3]
    self.assertEqual([holder], util.list_objects(holder))

    # Weights of a model owning the object still round-trip through disk.
    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def testDictWrapperNoDependency(self):
    """A NoDependency-wrapped dict adds nothing to the listed objects."""
    holder = tracking.AutoTrackable()
    holder.d = data_structures.NoDependency({})
    holder.d[1] = [3]
    self.assertEqual([holder], util.list_objects(holder))

    # Weights of a model owning the object still round-trip through disk.
    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def testNonStringKeyNotTrackableValue(self):
    """A non-string key is accepted when its value is wrapped in NoDependency."""
    holder = tracking.AutoTrackable()
    holder.d = {}
    holder.d["a"] = [3]
    holder.d[1] = data_structures.NoDependency([3])
    # Only the object, its dict, and the list under the string key are listed.
    self.assertEqual(
        [holder, holder.d, holder.d["a"]], util.list_objects(holder))

    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def testNonStringKeyNotTrackableValue(self):
    """A non-string key is accepted when its value is wrapped in NoDependency."""
    holder = tracking.AutoTrackable()
    holder.d = {}
    holder.d["a"] = [3]
    holder.d[1] = data_structures.NoDependency([3])
    # Only the object, its dict, and the list under the string key are listed.
    self.assertEqual(
        [holder, holder.d, holder.d["a"]], util.list_objects(holder))

    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def test_trackable_save_restore(self):
    """Save a template plus optimizer, then restore into a second template."""

    def _templated():
        zeros = init_ops.zeros_initializer
        v = variable_scope.get_variable(
            "v", shape=[1], initializer=zeros(), use_resource=True)
        v2 = variable_scope.get_variable(
            "v2", shape=[1], initializer=zeros(), use_resource=True)
        manual = _ManualScope()
        return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()

    # The template tracks exactly these five objects (compared by id).
    expected_ids = [
        id(v1_save),
        id(v2_save),
        id(manual_scope),
        id(manual_scope_v),
        id(save_template),
    ]
    six.assertCountEqual(
        self, expected_ids,
        map(id, trackable_utils.list_objects(save_template)))
    (manual_dep,) = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)

    optimizer = adam.Adam(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value, var_list=[v1_save])
    self.evaluate([var.initializer for var in save_template.variables])
    optimizer_variables = optimizer.variables() + list(
        optimizer._hyper.values())
    self.evaluate([var.initializer for var in optimizer_variables])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))

    checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.Adam(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value, var_list=[var])

    # Dependencies are expected in this exact order.
    self.assertLen(load_template._checkpoint_dependencies, 3)
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual(
        "ManualScope", load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
def testShallowCopyTrackable(self):
    """copy.copy yields a new root that shares children and tracks them."""
    source = tracking.AutoTrackable()
    shared_child = tracking.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": shared_child}

    clone = copy.copy(source)

    # The root is new, but the nested trackable is the very same object.
    self.assertIs(shared_child, clone.b["a"])
    self.assertIsNot(source, clone)
    self.assertEqual([[1.]], clone.a)

    clone_deps = util.list_objects(clone)
    self.assertIn(clone.a, clone_deps)
    self.assertIn(clone.b, clone_deps)
    self.assertIn(clone.b["a"], clone_deps)
def testShallowCopyTrackable(self):
    """copy.copy yields a new root that shares children and tracks them."""
    source = tracking.AutoTrackable()
    shared_child = tracking.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": shared_child}

    clone = copy.copy(source)

    # The root is new, but the nested trackable is the very same object.
    self.assertIs(shared_child, clone.b["a"])
    self.assertIsNot(source, clone)
    self.assertEqual([[1.]], clone.a)

    clone_deps = util.list_objects(clone)
    self.assertIn(clone.a, clone_deps)
    self.assertIn(clone.b, clone_deps)
    self.assertIn(clone.b["a"], clone_deps)
def testListBasic(self):
    """Items placed in a list attribute are tracked via the child named "l"."""
    owner = autotrackable.AutoTrackable()
    initial_item = autotrackable.AutoTrackable()
    owner.l = [initial_item]
    appended_item = autotrackable.AutoTrackable()
    owner.l.append(appended_item)

    all_deps = util.list_objects(owner)
    self.assertIn(initial_item, all_deps)
    self.assertIn(appended_item, all_deps)

    # The list itself is the direct child; the items live inside it.
    self.assertIn("l", owner._trackable_children())
    list_child = owner._trackable_children()["l"]
    self.assertIn(initial_item, list_child)
    self.assertIn(appended_item, list_child)
def testListBasic(self):
    """Items placed in a list attribute are tracked via the dependency "l"."""
    owner = tracking.AutoTrackable()
    initial_item = tracking.AutoTrackable()
    owner.l = [initial_item]
    appended_item = tracking.AutoTrackable()
    owner.l.append(appended_item)

    all_deps = util.list_objects(owner)
    self.assertIn(initial_item, all_deps)
    self.assertIn(appended_item, all_deps)

    # Exactly one direct dependency is expected: the list itself.
    (list_dep,) = owner._checkpoint_dependencies
    self.assertEqual("l", list_dep.name)
    self.assertIn(initial_item, list_dep.ref)
    self.assertIn(appended_item, list_dep.ref)
def testListBasic(self):
    """Items placed in a list attribute are tracked via the dependency "l"."""
    owner = tracking.AutoTrackable()
    initial_item = tracking.AutoTrackable()
    owner.l = [initial_item]
    appended_item = tracking.AutoTrackable()
    owner.l.append(appended_item)

    all_deps = util.list_objects(owner)
    self.assertIn(initial_item, all_deps)
    self.assertIn(appended_item, all_deps)

    # Exactly one direct dependency is expected: the list itself.
    (list_dep,) = owner._checkpoint_dependencies
    self.assertEqual("l", list_dep.name)
    self.assertIn(initial_item, list_dep.ref)
    self.assertIn(appended_item, list_dep.ref)
def test_trackable_save_restore(self):
    """Save a template plus optimizer, then restore into a second template."""

    def _templated():
        zeros = init_ops.zeros_initializer
        v = variable_scope.get_variable(
            "v", shape=[1], initializer=zeros(), use_resource=True)
        v2 = variable_scope.get_variable(
            "v2", shape=[1], initializer=zeros(), use_resource=True)
        manual = _ManualScope()
        return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()

    # The template tracks exactly these five objects.
    expected_objects = [
        v1_save, v2_save, manual_scope, manual_scope_v, save_template
    ]
    six.assertCountEqual(
        self, expected_objects, trackable_utils.list_objects(save_template))
    (manual_dep,) = manual_scope._checkpoint_dependencies
    self.assertEqual("in_manual_scope", manual_dep.name)
    self.assertIs(manual_scope_v, manual_dep.ref)

    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([var.initializer for var in save_template.variables])
    self.evaluate([var.initializer for var in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))

    checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)

    # Dependencies are expected in this exact order.
    self.assertEqual(3, len(load_template._checkpoint_dependencies))
    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
    self.assertEqual(
        "ManualScope", load_template._checkpoint_dependencies[2].name)
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
def testDeepCopyTrackable(self):
    """copy.deepcopy duplicates the root and its children, tracking the copies."""
    source = tracking.AutoTrackable()
    source_child = tracking.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": source_child}

    clone = copy.deepcopy(source)

    # Neither the root nor the nested trackable is shared with the source.
    self.assertIsNot(source, clone)
    self.assertIsNot(source_child, clone.b["a"])
    self.assertEqual([[1.]], clone.a)
    self.assertIsInstance(clone.b["a"], tracking.AutoTrackable)

    clone_deps = util.list_objects(clone)
    self.assertIn(clone.a, clone_deps)
    self.assertIn(clone.b, clone_deps)
    self.assertIn(clone.b["a"], clone_deps)
    self.assertNotIn(source_child, clone_deps)
def test_model(self):
    """Fit the UPerNet model for one epoch, then check config and tracking."""
    model = build_uper_net(
        classes=2,
        bone_arch='swin_tiny_224',
        bone_init='imagenet',
        bone_train=False)
    model.compile(
        optimizer='sgd',
        loss='mse',
        run_eagerly=test_utils.should_run_eagerly())

    # np.random.random() returns floats in [0, 1); casting those to uint8
    # truncates every value to 0, so the model would only ever see constant
    # black images. Draw actual 8-bit values instead.
    images = np.random.randint(0, 256, (2, 240, 240, 3), dtype=np.uint8)
    targets = np.random.random((2, 240, 240, 2)).astype(np.float32)
    model.fit(images, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # check whether the model variables are present
    # in the trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
        self.assertIn(v, checkpointed_objects)
def test_trackable_save_restore(self):
    """Save a template plus optimizer, then restore into a second template."""

    def _templated():
        zeros = init_ops.zeros_initializer
        v = variable_scope.get_variable(
            "v", shape=[1], initializer=zeros(), use_resource=True)
        v2 = variable_scope.get_variable(
            "v2", shape=[1], initializer=zeros(), use_resource=True)
        manual = _ManualScope()
        return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()

    # The template tracks exactly these five objects (compared by id).
    expected_objects = [
        v1_save, v2_save, manual_scope, manual_scope_v, save_template
    ]
    six.assertCountEqual(
        self,
        [id(obj) for obj in expected_objects],
        [id(obj) for obj in trackable_utils.list_objects(save_template)])
    self.assertDictEqual({"in_manual_scope": manual_scope_v},
                         manual_scope._trackable_children())

    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([var.initializer for var in save_template.variables])
    self.evaluate([var.initializer for var in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))

    checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)

    self.assertEqual(3, len(load_template._trackable_children()))
    self.assertEqual({"v", "v2", "ManualScope"},
                     load_template._trackable_children().keys())
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))
def testDeepCopyTrackable(self):
    """copy.deepcopy duplicates the root and its children, tracking the copies."""
    source = tracking.AutoTrackable()
    source_child = tracking.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": source_child}
    self.assertIsInstance(source.b, dict)

    clone = copy.deepcopy(source)

    # The dict attribute is still a dict after the copy.
    self.assertIsInstance(clone.b, dict)
    # Neither the root nor the nested trackable is shared with the source.
    self.assertIsNot(source, clone)
    self.assertIsNot(source_child, clone.b["a"])
    self.assertEqual([[1.]], clone.a)
    self.assertIsInstance(clone.b["a"], tracking.AutoTrackable)

    clone_deps = util.list_objects(clone)
    self.assertIn(clone.a, clone_deps)
    self.assertIn(clone.b, clone_deps)
    self.assertIn(clone.b["a"], clone_deps)
    self.assertNotIn(source_child, clone_deps)
def test_trackable_dependencies(self):
    """Fit a SimpleRNN model and verify its variables are tracked."""
    rnn = keras.layers.SimpleRNN
    x = np.random.random((2, 2, 2))
    y = np.random.random((2, 2))
    model = keras.models.Sequential()
    model.add(rnn(2))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, y, epochs=1, batch_size=1)

    # check whether the model variables are present in the
    # trackable list of objects
    # Membership is checked by identity: a plain set() hashes its elements
    # and compares them with ==, which breaks for variables that are
    # unhashable or overload equality.
    checkpointed_object_ids = {
        id(o) for o in trackable_util.list_objects(model)
    }
    for v in model.variables:
        self.assertIn(id(v), checkpointed_object_ids)
def test_trackable_dependencies(self):
    """Fit a SimpleRNN model and verify its variables are tracked."""
    layer_cls = keras.layers.SimpleRNN
    features = np.random.random((2, 2, 2))
    targets = np.random.random((2, 2))

    model = keras.models.Sequential()
    model.add(layer_cls(2))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(features, targets, epochs=1, batch_size=1)

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
def test_weight_norm_dense(self):
    """Fit a WeightNorm(Dense) model and verify its variables are tracked."""
    wrapped = WeightNorm(tf.keras.layers.Dense(2), input_shape=(3, 4))
    model = tf.keras.models.Sequential()
    model.add(wrapped)
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(features, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
def test_custom_backward_layer(self, mode):
    """Fit a Bidirectional model with an explicit backward layer.

    Args:
        mode: merge mode handed to Bidirectional; "concat" doubles the
            width of the target array.
    """
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2

    x = np.random.random((samples, timesteps, dim))
    if mode == "concat":
        target_dim = 2 * output_dim
    else:
        target_dim = output_dim
    y = np.random.random((samples, target_dim))

    forward_layer = rnn(output_dim)
    backward_layer = rnn(output_dim, go_backwards=True)

    # test with Sequential model
    bidirectional = keras.layers.Bidirectional(
        forward_layer,
        merge_mode=mode,
        backward_layer=backward_layer,
        input_shape=(timesteps, dim),
    )
    model = keras.models.Sequential()
    model.add(bidirectional)
    model.compile(optimizer="rmsprop", loss="mse")
    model.fit(x, y, epochs=1, batch_size=1)

    # Every model variable must appear (by identity) among the objects
    # reachable through trackable dependencies.
    tracked_ids = set(map(id, trackable_util.list_objects(model)))
    for variable in model.variables:
        self.assertIn(id(variable), tracked_ids)

    # test compute output shape
    last_layer = model.layers[-1]
    ref_shape = last_layer.output.shape
    shape = model.layers[-1].compute_output_shape((None, timesteps, dim))
    self.assertListEqual(shape.as_list(), ref_shape.as_list())

    # test config
    model.get_config()
    model = keras.models.model_from_json(model.to_json())
    model.summary()
def testDictionariesBasic(self):
    """Models stored in a dict attribute are tracked, listed as layers, saved."""
    a = training.Model()
    b = training.Model()
    a.attribute = {"b": b}
    c = training.Model()
    # Nest a list inside the dict and add a model to it afterwards.
    a.attribute["c"] = []
    a.attribute["c"].append(c)

    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertIs(b, a.attribute["b"])
    self.assertEqual(
        {"b", "c"}, a.attribute._trackable_children().keys())
    self.assertEqual([b, c], a.layers)
    self.assertEqual([b, c], a.attribute.layers)
    self.assertEqual([c], a.attribute["c"].layers)

    checkpoint = tf.train.Checkpoint(a=a)
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = checkpoint.save(ckpt_prefix)
    with self.cached_session():
        status = checkpoint.restore(save_path)
        status.assert_consumed().initialize_or_restore()
def test_model(self):
    """Fit a TemporalConvNet model and verify its variables are tracked."""
    tcn = TemporalConvNet(filters=[5, 4, 3], kernel_size=3, dropout=0.2)
    model = tf.keras.models.Sequential()
    model.add(tcn)
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 3))
    model.fit(features, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
def test_model(self):
    """Fit the DexiNed model for one epoch, then check config and tracking."""
    num_classes = 1
    model = build_dexi_ned(num_classes)
    model.compile(
        optimizer='sgd',
        loss='binary_crossentropy',
        run_eagerly=test_utils.should_run_eagerly())

    # np.random.random() returns floats in [0, 1); casting those to uint8
    # truncates every value to 0, so the model would only ever see constant
    # black images. Draw actual 8-bit values instead.
    images = np.random.randint(0, 256, (2, 224, 224, 3), dtype=np.uint8)
    # NOTE(review): randint's upper bound is exclusive, so with
    # num_classes == 1 every label is 0 -- confirm that is intended.
    labels = np.random.randint(0, num_classes, (2, 224, 224))
    model.fit(images, labels, epochs=1, batch_size=1)

    # test config
    model.get_config()

    # check whether the model variables are present
    # in the trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
        self.assertIn(v, checkpointed_objects)
def testNonAppendNotTrackable(self):
    """Overwriting or deleting untracked dict values leaves the object saveable."""
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    holder = tracking.AutoTrackable()
    holder.d = {}
    holder.d["a"] = [3]
    holder.d[1] = 3
    holder.d[1] = 2
    self.assertEqual(2, holder.d[1])
    del holder.d[1]

    # Replacing one NoDependency-wrapped trackable with another is allowed.
    holder.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
    second = tracking.AutoTrackable()
    holder.d[2] = data_structures.NoDependency(second)
    self.assertIs(second, holder.d[2])
    self.assertEqual(
        [holder, holder.d, holder.d["a"]], util.list_objects(holder))

    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def testNonAppendNotTrackable(self):
    """Overwriting or deleting untracked dict values leaves the object saveable."""
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    holder = tracking.AutoTrackable()
    holder.d = {}
    holder.d["a"] = [3]
    holder.d[1] = 3
    holder.d[1] = 2
    self.assertEqual(2, holder.d[1])
    del holder.d[1]

    # Replacing one NoDependency-wrapped trackable with another is allowed.
    holder.d[2] = data_structures.NoDependency(tracking.AutoTrackable())
    second = tracking.AutoTrackable()
    holder.d[2] = data_structures.NoDependency(second)
    self.assertIs(second, holder.d[2])
    self.assertEqual(
        [holder, holder.d, holder.d["a"]], util.list_objects(holder))

    model = training.Model()
    model.sub = holder
    weights_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_prefix)
    model.load_weights(weights_prefix)
def testNoDependency(self):
    """NoDependency attributes and decorated constructors add no children."""
    root = tf.Module()
    tracked_child = tf.Module()
    root.hasdep = tracked_child
    untracked_child = tf.Module()
    root.nodep = data_structures.NoDependency(untracked_child)

    # Only "hasdep" is a trackable child; both attributes stay readable.
    self.assertLen(root._trackable_children(), 1)
    self.assertIs(root._trackable_children()["hasdep"], root.hasdep)
    self.assertIs(root.hasdep, tracked_child)
    self.assertIs(root.nodep, untracked_child)

    class NoDependencyModel(training.Model):

        @tf.__internal__.tracking.no_automatic_dependency_tracking
        def __init__(self):
            super(NoDependencyModel, self).__init__()
            self.a = []
            self.b = tf.Module()

    untracked_model = NoDependencyModel()
    self.assertEqual([untracked_model], util.list_objects(untracked_model))
def testNoDependency(self):
    """NoDependency attributes and decorated constructors add no dependencies."""
    root = module.Module()
    tracked_child = module.Module()
    root.hasdep = tracked_child
    untracked_child = module.Module()
    root.nodep = data_structures.NoDependency(untracked_child)

    # Only "hasdep" is a checkpoint dependency; both attributes stay readable.
    self.assertEqual(1, len(root._checkpoint_dependencies))
    self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
    self.assertIs(root.hasdep, tracked_child)
    self.assertIs(root.nodep, untracked_child)

    class NoDependencyModel(training.Model):

        @base.no_automatic_dependency_tracking
        def __init__(self):
            super(NoDependencyModel, self).__init__()
            self.a = []
            self.b = module.Module()

    untracked_model = NoDependencyModel()
    self.assertEqual([untracked_model], util.list_objects(untracked_model))
def testNoDependency(self):
    """NoDependency attributes and decorated constructors add no dependencies."""
    root = tracking.AutoTrackable()
    tracked_child = tracking.AutoTrackable()
    root.hasdep = tracked_child
    untracked_child = tracking.AutoTrackable()
    root.nodep = data_structures.NoDependency(untracked_child)

    # Only "hasdep" is a checkpoint dependency; both attributes stay readable.
    self.assertEqual(1, len(root._checkpoint_dependencies))
    self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
    self.assertIs(root.hasdep, tracked_child)
    self.assertIs(root.nodep, untracked_child)

    class NoDependencyModel(training.Model):

        @base.no_automatic_dependency_tracking
        def __init__(self):
            super(NoDependencyModel, self).__init__()
            self.a = []
            self.b = tracking.AutoTrackable()

    untracked_model = NoDependencyModel()
    self.assertEqual([untracked_model], util.list_objects(untracked_model))
def test_timedistributed_dense(self):
    """Fit a TimeDistributed(Dense) model and verify its variables are tracked."""
    wrapped = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4))
    model = keras.models.Sequential()
    model.add(wrapped)
    model.compile(optimizer='rmsprop', loss='mse')

    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(features, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
def test_custom_backward_layer(self, mode):
    """Fit a Bidirectional model with an explicit backward layer.

    Args:
        mode: merge mode handed to Bidirectional; 'concat' doubles the
            width of the target array.
    """
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2

    x = np.random.random((samples, timesteps, dim))
    if mode == 'concat':
        target_dim = 2 * output_dim
    else:
        target_dim = output_dim
    y = np.random.random((samples, target_dim))

    forward_layer = rnn(output_dim)
    backward_layer = rnn(output_dim, go_backwards=True)

    # test with Sequential model
    bidirectional = keras.layers.Bidirectional(
        forward_layer,
        merge_mode=mode,
        backward_layer=backward_layer,
        input_shape=(timesteps, dim))
    model = keras.models.Sequential()
    model.add(bidirectional)
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, epochs=1, batch_size=1)

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)

    # test compute output shape
    last_layer = model.layers[-1]
    ref_shape = last_layer.output.shape
    shape = model.layers[-1].compute_output_shape((None, timesteps, dim))
    self.assertListEqual(shape.as_list(), ref_shape.as_list())

    # test config
    model.get_config()
    model = keras.models.model_from_json(model.to_json())
    model.summary()
def testListDeepCopy(self):
    """Deep-copying a tracked list copies its contents and its dirtiness."""
    root = tracking.AutoTrackable()
    raw_list = [[1.]]
    root.a = raw_list

    duplicate = copy.deepcopy(root.a)
    self.assertAllEqual([[1.]], duplicate)
    self.assertIsNot(root.a, duplicate)
    self.assertIsNot(root.a[0], duplicate[0])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    # Mutate the original list object directly rather than through root.a.
    raw_list.append(1.)
    with self.assertRaises(ValueError):
        util.list_objects(root.a)
    with self.assertRaises(ValueError):
        util.list_objects(copy.deepcopy(root.a))
def testDictDeepCopy(self):
    """Deep-copying a tracked dict copies its contents and its dirtiness."""
    root = tracking.AutoTrackable()
    raw_dict = {"a": [1.]}
    root.a = raw_dict

    duplicate = copy.deepcopy(root.a)
    self.assertAllEqual([1.], duplicate["a"])
    self.assertIsNot(root.a, duplicate)
    self.assertIsNot(root.a["a"], duplicate["a"])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    # Mutate the original dict object directly rather than through root.a.
    raw_dict["b"] = []
    with self.assertRaises(ValueError):
        util.list_objects(root.a)
    with self.assertRaises(ValueError):
        util.list_objects(copy.deepcopy(root.a))
def testListDeepCopy(self):
    """Deep-copying a tracked list copies its contents and its dirtiness."""
    root = tracking.AutoTrackable()
    raw_list = [[1.]]
    root.a = raw_list

    duplicate = copy.deepcopy(root.a)
    self.assertAllEqual([[1.]], duplicate)
    self.assertIsNot(root.a, duplicate)
    self.assertIsNot(root.a[0], duplicate[0])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    # Mutate the original list object directly rather than through root.a.
    raw_list.append(1.)
    with self.assertRaises(ValueError):
        util.list_objects(root.a)
    with self.assertRaises(ValueError):
        util.list_objects(copy.deepcopy(root.a))
def testDictDeepCopy(self):
    """Deep-copying a tracked dict copies its contents and its dirtiness."""
    root = tracking.AutoTrackable()
    raw_dict = {"a": [1.]}
    root.a = raw_dict

    duplicate = copy.deepcopy(root.a)
    self.assertAllEqual([1.], duplicate["a"])
    self.assertIsNot(root.a, duplicate)
    self.assertIsNot(root.a["a"], duplicate["a"])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    # Mutate the original dict object directly rather than through root.a.
    raw_dict["b"] = []
    with self.assertRaises(ValueError):
        util.list_objects(root.a)
    with self.assertRaises(ValueError):
        util.list_objects(copy.deepcopy(root.a))
def test_model(self):
    """Fit DeepLabV3+ with PointRend for one epoch, then check config/tracking."""
    num_classes = 5
    model = build_deeplab_v3_plus_with_point_rend(
        classes=num_classes,
        bone_arch='resnet_50',
        bone_init='imagenet',
        bone_train=False,
        aspp_filters=8,
        aspp_stride=16,
        low_filters=16,
        decoder_filters=4,
        rend_strides=(2, 4),
        rend_units=(256, ),
        rend_points=(0.1697, 0.0005),
        rend_oversample=3,
        rend_importance=0.75,
        rend_weights=True,
        rend_corners=True)
    model.compile(
        optimizer='sgd',
        loss=['sparse_categorical_crossentropy', None],
        run_eagerly=test_utils.should_run_eagerly())

    inputs = {
        # np.random.random() returns floats in [0, 1); casting those to
        # uint8 truncates every value to 0, so the model would only ever
        # see constant black images. Draw actual 8-bit values instead.
        'image': np.random.randint(0, 256, (2, 224, 224, 3), dtype=np.uint8),
        'label': np.random.randint(0, num_classes, (2, 224, 224)),
        'weight': np.random.rand(2, 224, 224),
    }
    targets = np.random.randint(0, num_classes, (2, 224, 224))
    model.fit(inputs, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # check whether the model variables are present
    # in the trackable list of objects
    checkpointed_objects = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for v in model.variables:
        self.assertIn(v, checkpointed_objects)
def testDictionariesBasic(self):
    """Models stored in a dict attribute are tracked, listed as layers, saved."""
    a = training.Model()
    b = training.Model()
    a.attribute = {"b": b}
    c = training.Model()
    # Nest a list inside the dict and add a model to it afterwards.
    a.attribute["c"] = []
    a.attribute["c"].append(c)

    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertIs(b, a.attribute["b"])
    dependency_names = [
        dep.name for dep in a.attribute._checkpoint_dependencies
    ]
    six.assertCountEqual(self, ["b", "c"], dependency_names)
    self.assertEqual([b, c], a.layers)
    self.assertEqual([b, c], a.attribute.layers)
    self.assertEqual([c], a.attribute["c"].layers)

    checkpoint = util.Checkpoint(a=a)
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = checkpoint.save(ckpt_prefix)
    with self.cached_session():
        status = checkpoint.restore(save_path)
        status.assert_consumed().initialize_or_restore()
def test_model(self):
    """Fit a stacked QRNN model and verify its variables are tracked."""
    first_layer = tf.keras.layers.Bidirectional(
        QRNN(units=12, window=2, zoneout=0.2, return_sequences=True))
    model = tf.keras.models.Sequential()
    model.add(first_layer)
    model.add(QRNN(units=2, window=1))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    features = np.random.random((10, 3, 4))
    targets = np.random.random((10, 2))
    model.fit(features, targets, epochs=1, batch_size=10)

    # test config
    model.get_config()

    # Every model variable must be among the objects reachable through
    # trackable dependencies.
    tracked = object_identity.ObjectIdentitySet(
        trackable_util.list_objects(model))
    for variable in model.variables:
        self.assertIn(variable, tracked)
def testDictionariesBasic(self):
    """Models stored in a dict attribute are tracked, listed as layers, saved."""
    a = training.Model()
    b = training.Model()
    a.attribute = {"b": b}
    c = training.Model()
    # Nest a list inside the dict and add a model to it afterwards.
    a.attribute["c"] = []
    a.attribute["c"].append(c)

    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertIs(b, a.attribute["b"])
    dependency_names = [
        dep.name for dep in a.attribute._checkpoint_dependencies
    ]
    six.assertCountEqual(self, ["b", "c"], dependency_names)
    self.assertEqual([b, c], a.layers)
    self.assertEqual([b, c], a.attribute.layers)
    self.assertEqual([c], a.attribute["c"].layers)

    checkpoint = util.Checkpoint(a=a)
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = checkpoint.save(ckpt_prefix)
    with self.cached_session():
        status = checkpoint.restore(save_path)
        status.assert_consumed().initialize_or_restore()
def test_bidirectional(self):
    """Fit a Bidirectional(SimpleRNN) model once for each merge mode."""
    rnn = keras.layers.SimpleRNN
    samples = 2
    dim = 2
    timesteps = 2
    output_dim = 2
    merge_modes = ['sum', 'concat', 'ave', 'mul']

    with self.cached_session():
        for mode in merge_modes:
            x = np.random.random((samples, timesteps, dim))
            # Only 'concat' doubles the width of the target array.
            if mode == 'concat':
                target_dim = 2 * output_dim
            else:
                target_dim = output_dim
            y = np.random.random((samples, target_dim))

            # test with Sequential model
            bidirectional = keras.layers.Bidirectional(
                rnn(output_dim),
                merge_mode=mode,
                input_shape=(timesteps, dim))
            model = keras.models.Sequential()
            model.add(bidirectional)
            model.compile(optimizer='rmsprop', loss='mse')
            model.fit(x, y, epochs=1, batch_size=1)

            # Every model variable must be among the objects reachable
            # through trackable dependencies.
            tracked = object_identity.ObjectIdentitySet(
                trackable_util.list_objects(model))
            for variable in model.variables:
                self.assertIn(variable, tracked)

            # test compute output shape
            last_layer = model.layers[-1]
            ref_shape = last_layer.output.shape
            shape = model.layers[-1].compute_output_shape(
                (None, timesteps, dim))
            self.assertListEqual(shape.as_list(), ref_shape.as_list())

            # test config
            model.get_config()
            model = keras.models.model_from_json(model.to_json())
            model.summary()