def test_GraphKeys(self):
  v0 = de.Variable(key_dtype=dtypes.int64,
                   value_dtype=dtypes.float32,
                   initializer=0.0,
                   name="v0")
  v1 = de.Variable(key_dtype=dtypes.int64,
                   value_dtype=dtypes.float32,
                   initializer=0.0,
                   name="v1",
                   trainable=False)
  v2 = de.get_variable(
      "v2",
      key_dtype=dtypes.int64,
      value_dtype=dtypes.float32,
      initializer=init_ops.zeros_initializer,
      dim=10,
  )
  v3 = de.get_variable("v3",
                       key_dtype=dtypes.int64,
                       value_dtype=dtypes.float32,
                       initializer=init_ops.zeros_initializer,
                       dim=10,
                       trainable=False)
  de_vars = ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES)
  self.assertSetEqual(set([v0, v1, v2, v3]), set(de_vars))
  de_trainable_vars = ops.get_collection(
      de.GraphKeys.TRAINABLE_DYNAMIC_EMBEDDING_VARIABLES)
  self.assertAllEqual(set([v0, v2]), set(de_trainable_vars))
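# A hedged usage sketch (not from the source): once tables are registered in
# de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, the collection can be walked to
# inspect every dynamic-embedding table in the graph. `sess` is assumed to be
# an active session; the helper name is illustrative only.
def log_dynamic_embedding_sizes(sess):
  for var in ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES):
    # de.Variable.size() returns a scalar tensor holding the key count.
    print(var.name, sess.run(var.size()))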
def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for stateful optimizers."""
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          init_size=params_var_.init_size,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )
    scope_store._vars[full_name] = slot_variable_

  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable
def create_slots(primary, init, slot_name, op_name, bp_v2):
  """Helper function for creating a slot variable for stateful optimizers."""
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          kv_creator=params_var_.kv_creator,
          trainable=False,
          checkpoint=params_var_.checkpoint,
          bp_v2=bp_v2 if bp_v2 is not None else params_var_.bp_v2,
      )
    scope_store._vars[full_name] = slot_variable_
    # Record the optimizer variable into the trace.
    primary._optimizer_vars.append(slot_variable_)

  slot_trainable = None
  if context.executing_eagerly():
    slot_tw_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
  else:
    # In graph mode, earlier versions used only slot_name to name the
    # trainable wrappers of slots, so keep that name here for forward
    # compatibility.
    slot_tw_name = slot_name
  if isinstance(primary, de.shadow_ops.ShadowVariable):
    slot_trainable = de.shadow_ops.ShadowVariable(
        params=scope_store._vars[full_name],
        ids=primary.ids,
        exists=primary.exists,
        name=full_name,
        trainable=False,
    )
  else:
    _, slot_trainable = de.embedding_lookup(
        params=scope_store._vars[full_name],
        ids=params_ids_,
        name=slot_tw_name,
        return_trainable=True,
    )

  return slot_trainable
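# A hedged sketch of how an optimizer's slot-creation hook might call the
# bp_v2-aware create_slots above for Adam-style first/second moment slots.
# `trainable_wrapper` stands for a de.TrainableWrapper (or ShadowVariable)
# produced by de.embedding_lookup(..., return_trainable=True); the zero
# initializer, the "Adam" op name, and the helper name are assumptions, not
# the library's actual hook.
def _create_adam_slots(trainable_wrapper):
  m = create_slots(trainable_wrapper, 0.0, "m", "Adam", bp_v2=None)
  v = create_slots(trainable_wrapper, 0.0, "v", "Adam", bp_v2=None)
  return m, v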
def test_saving_restoring_checkpoint(self):
  logdir = _test_dir(self.get_temp_dir(), "test_saving_restoring_checkpoint")
  with ops.Graph().as_default():
    gstep = training_util.create_global_step()
    do_step = state_ops.assign_add(gstep, 1)

    v0 = variables.Variable(10.0, name="v0")
    v1 = variables.Variable(20.0, name="v1")

    target_values = [[0.0], [1.0], [2.0]]
    keys = array_ops.placeholder(dtypes.int64)
    values = constant_op.constant(target_values, dtypes.float32)

    table = de.Variable(
        key_dtype=dtypes.int64,
        value_dtype=dtypes.float32,
        initializer=-1.0,
        name="m100",
        dim=1,
    )
    upsert_op = table.upsert(keys, values)
    lookup_op = table.lookup(keys)
    size_op = table.size()

    with monitored_session.MonitoredTrainingSession(
        config=default_config, is_chief=True, checkpoint_dir=logdir) as sess:
      self.assertEqual(0, sess.run(gstep))
      self.assertEqual(1, sess.run(do_step))
      self.assertEqual(2, sess.run(do_step))

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, sess.run(v0))
      self.assertEqual(20.0, sess.run(v1))
      self.assertAllEqual(0, sess.run(size_op))
      sess.run(upsert_op, feed_dict={keys: [0, 1, 2]})
      self.assertAllEqual(3, sess.run(size_op))
      self.device_check(table)

    # A restart will find the checkpoint and recover automatically.
    with monitored_session.MonitoredTrainingSession(
        config=default_config, is_chief=True, checkpoint_dir=logdir) as sess:
      self.assertEqual(2, sess.run(gstep))
      self.assertAllEqual(3, sess.run(table.size()))
      self.assertAllEqual(target_values,
                          sess.run(lookup_op, feed_dict={keys: [0, 1, 2]}))
      self.device_check(table)
def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for stateful optimizers."""
  # lwk: @primary is a de.TrainableWrapper,
  # lwk: so @params_var_ is a de.Variable and params_ids_ are the feature IDs.
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      # lwk: the slot variable of a de.Variable is itself a de.Variable,
      # i.e. another hash table.
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,  # same dimension as the primary
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )
    scope_store._vars[full_name] = slot_variable_

  # lwk: so the dynamic hash table essentially works by wrapping the lookup
  # result into a variable. That variable carries the hash table and the ids,
  # and the optimizer uses the same information to build a matching slot
  # variable.
  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable
def test_save_restore_only_table(self):
  save_dir = os.path.join(self.get_temp_dir(), "save_restore")
  save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

  with self.session(
      config=default_config,
      graph=ops.Graph(),
      use_gpu=test_util.is_gpu_available(),
  ) as sess:
    v0 = variables.Variable(10.0, name="v0")
    v1 = variables.Variable(20.0, name="v1")

    default_val = -1
    keys = constant_op.constant([0, 1, 2], dtypes.int64)
    values = constant_op.constant([[0], [1], [2]], dtypes.int32)
    table = de.Variable(
        dtypes.int64,
        dtypes.int32,
        name="t1",
        initializer=default_val,
        checkpoint=True,
    )

    save = saver.Saver([table])
    self.evaluate(variables.global_variables_initializer())

    # Check that the parameter nodes have been initialized.
    self.assertEqual(10.0, self.evaluate(v0))
    self.assertEqual(20.0, self.evaluate(v1))

    self.assertAllEqual(0, self.evaluate(table.size()))
    self.evaluate(table.upsert(keys, values))
    self.assertAllEqual(3, self.evaluate(table.size()))

    val = save.save(sess, save_path)
    self.assertIsInstance(val, six.string_types)
    self.assertEqual(save_path, val)
    del table

  with self.session(
      config=default_config,
      graph=ops.Graph(),
      use_gpu=test_util.is_gpu_available(),
  ) as sess:
    default_val = -1
    table = de.Variable(
        dtypes.int64,
        dtypes.int32,
        name="t1",
        initializer=default_val,
        checkpoint=True,
    )
    self.evaluate(
        table.upsert(
            constant_op.constant([0, 2], dtypes.int64),
            constant_op.constant([[12], [24]], dtypes.int32),
        ))
    self.assertAllEqual(2, self.evaluate(table.size()))

    save = saver.Saver([table._tables[0]])

    # Restore the saved values in the parameter nodes.
    save.restore(sess, save_path)

    # Check that the parameter nodes have been restored.
    self.assertAllEqual(3, self.evaluate(table.size()))

    remove_keys = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
    output = table.lookup(remove_keys)
    self.assertAllEqual([[0], [1], [2], [-1], [-1]], self.evaluate(output))

    del table
def test_save_restore(self):
  save_dir = os.path.join(self.get_temp_dir(), "save_restore")
  save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

  with self.session(config=default_config, graph=ops.Graph()) as sess:
    v0 = variables.Variable(10.0, name="v0")
    v1 = variables.Variable(20.0, name="v1")

    keys = constant_op.constant([0, 1, 2], dtypes.int64)
    values = constant_op.constant([[0.0], [1.0], [2.0]], dtypes.float32)
    table = de.Variable(
        key_dtype=dtypes.int64,
        value_dtype=dtypes.float32,
        initializer=-1.0,
        name="t1",
        dim=1,
    )

    save = saver.Saver(var_list=[v0, v1, table])
    self.evaluate(variables.global_variables_initializer())

    # Check that the parameter nodes have been initialized.
    self.assertEqual(10.0, self.evaluate(v0))
    self.assertEqual(20.0, self.evaluate(v1))

    self.assertAllEqual(0, self.evaluate(table.size()))
    self.evaluate(table.upsert(keys, values))
    self.assertAllEqual(3, self.evaluate(table.size()))

    val = save.save(sess, save_path)
    self.assertIsInstance(val, six.string_types)
    self.assertEqual(save_path, val)
    del table

  with self.session(config=default_config, graph=ops.Graph()) as sess:
    v0 = variables.Variable(-1.0, name="v0")
    v1 = variables.Variable(-1.0, name="v1")
    table = de.Variable(
        name="t1",
        key_dtype=dtypes.int64,
        value_dtype=dtypes.float32,
        initializer=-1.0,
        dim=1,
        checkpoint=True,
    )
    self.evaluate(
        table.upsert(
            constant_op.constant([0, 1], dtypes.int64),
            constant_op.constant([[12.0], [24.0]], dtypes.float32),
        ))
    size_op = table.size()
    self.assertAllEqual(2, self.evaluate(size_op))

    save = saver.Saver(var_list=[v0, v1, table])

    # Restore the saved values in the parameter nodes.
    save.restore(sess, save_path)

    # Check that the parameter nodes have been restored.
    self.assertEqual([10.0], self.evaluate(v0))
    self.assertEqual([20.0], self.evaluate(v1))
    self.assertAllEqual(3, self.evaluate(table.size()))

    remove_keys = constant_op.constant([5, 0, 1, 2, 6], dtypes.int64)
    output = table.lookup(remove_keys)
    self.assertAllEqual([[-1.0], [0.0], [1.0], [2.0], [-1.0]],
                        self.evaluate(output))

    del table
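# A hedged sketch combining the patterns above: build a Saver over only the
# dynamic-embedding tables by filtering the graph collection, assuming the
# same imports as the tests (the Saver accepts de.Variable instances in
# var_list, as shown above). The helper name is illustrative only.
def de_only_saver():
  de_tables = ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES)
  return saver.Saver(var_list=list(de_tables))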