def test_sharded_multi_lookup_on_one_variable(self):
  embeddings = de.get_variable(
      "t340",
      dtypes.int64,
      dtypes.float32,
      devices=_get_devices() * 3,
      initializer=2.0,
  )
  ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
  vals = constant_op.constant([[0.0], [1.0], [2.0], [3.0], [4.0]],
                              dtype=dtypes.float32)
  new_vals = constant_op.constant([[10.0], [11.0], [12.0], [13.0], [14.0]],
                                  dtype=dtypes.float32)
  ids0 = constant_op.constant([1, 3, 2], dtype=dtypes.int64)
  ids1 = constant_op.constant([3, 4], dtype=dtypes.int64)
  embedding0 = de.embedding_lookup(embeddings, ids0)
  embedding1 = de.embedding_lookup(embeddings, ids1)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    self.evaluate(embeddings.upsert(ids, vals))
    self.assertAllClose(embedding0.eval(), [[1.0], [3.0], [2.0]])
    self.assertAllEqual([3, 1], embedding0.eval().shape)
    self.assertAllClose(embedding1.eval(), [[3.0], [4.0]])
    self.assertAllEqual([2, 1], embedding1.eval().shape)
    self.evaluate(embeddings.upsert(ids, new_vals))
    self.assertAllClose(embedding1.eval(), [[13.0], [14.0]])
    self.assertAllEqual([2, 1], embedding1.eval().shape)

def test_scope_reuse_embedding_lookup(self):
  ids = constant_op.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                             dtype=dtypes.int64)
  with variable_scope.variable_scope("test", reuse=variable_scope.AUTO_REUSE):
    p1 = de.get_variable(name="p1")
    with variable_scope.variable_scope("q"):
      _, t1 = de.embedding_lookup(p1, ids, name="emb", return_trainable=True)

  with variable_scope.variable_scope("test", reuse=variable_scope.AUTO_REUSE):
    p1_reuse = de.get_variable(name="p1")
    p2 = de.get_variable(name="p2")
    with variable_scope.variable_scope("q"):
      _, t2 = de.embedding_lookup(p2, ids, name="emb", return_trainable=True)

  self.assertAllEqual(p1.name, "test/p1")
  self.assertAllEqual(p2.name, "test/p2")
  self.assertAllEqual(p1, p1_reuse)
  self.assertEqual(t1.name, "test/q/emb/TrainableWrapper:0")
  self.assertEqual(t2.name, "test/q/emb/TrainableWrapper_1:0")
  self.assertAllEqual(p1._tables[0].name, "test_p1_mht_1of1")
  self.assertAllEqual(p1_reuse._tables[0].name, "test_p1_mht_1of1")
  self.assertAllEqual(p2._tables[0].name, "test_p2_mht_1of1")

def test_embedding_lookup_shape(self):

  def _evaluate(tensors, feed_dict):
    sess = ops.get_default_session()
    if sess is None:
      with self.test_session() as sess:
        return sess.run(tensors, feed_dict=feed_dict)
    else:
      return sess.run(tensors, feed_dict=feed_dict)

  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    default_val = -1
    keys = constant_op.constant([0, 1, 2], dtypes.int64)
    values = constant_op.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2]],
                                  dtypes.int32)
    table = de.get_variable("t140",
                            dtypes.int64,
                            dtypes.int32,
                            dim=3,
                            initializer=default_val)
    self.evaluate(table.upsert(keys, values))
    self.assertAllEqual(3, self.evaluate(table.size()))

    # shape of ids is fully defined
    ids = constant_op.constant([[0, 1], [2, 4]], dtypes.int64)
    embeddings = de.embedding_lookup(table, ids)
    self.assertAllEqual([2, 2, 3], embeddings.get_shape())
    re = self.evaluate(embeddings)
    self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                        re)

    # shape of ids is partially defined
    ids = gen_array_ops.placeholder(shape=(2, None), dtype=dtypes.int64)
    embeddings = de.embedding_lookup(table, ids)
    self.assertFalse(embeddings.get_shape().is_fully_defined())
    re = _evaluate(
        embeddings,
        feed_dict={ids: np.asarray([[0, 1], [2, 4]], dtype=np.int64)})
    self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                        re)

    # shape of ids is unknown
    ids = gen_array_ops.placeholder(dtype=dtypes.int64)
    embeddings = de.embedding_lookup(table, ids)
    self.assertEqual(embeddings.get_shape(), tensor_shape.unknown_shape())
    re = _evaluate(
        embeddings,
        feed_dict={ids: np.asarray([[0, 1], [2, 4]], dtype=np.int64)})
    self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                        re)

def test_higher_rank(self):
  np.random.seed(8)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    for dim in [1, 10]:
      for ids_shape in [[3, 2], [4, 3], [4, 3, 10]]:
        with variable_scope.variable_scope("test_higher_rank", reuse=True):
          params = de.get_variable(
              "t350-" + str(dim),
              dtypes.int64,
              dtypes.float32,
              initializer=2.0,
              dim=dim,
          )
          ids = np.random.randint(2**31,
                                  size=np.prod(ids_shape),
                                  dtype=np.int64).reshape(ids_shape)
          ids = constant_op.constant(ids, dtype=dtypes.int64)
          simple = params.lookup(ids)
          self.evaluate(params.upsert(ids, simple))
          embedding = de.embedding_lookup(params, ids)
          self.assertAllEqual(simple.eval(), embedding.eval())
          self.assertAllEqual(ids_shape + [dim], embedding.eval().shape)

def test_sharded_custom_partitioner_int32_ids(self):

  def _partition_fn(keys, shard_num):
    return math_ops.cast(keys % 2, dtype=dtypes.int32)

  embeddings = de.get_variable(
      "t330",
      dtypes.int64,
      dtypes.float32,
      partitioner=_partition_fn,
      devices=_get_devices() * 3,
      initializer=2.0,
  )
  ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
  vals = constant_op.constant([[0.0], [1.0], [2.0], [3.0], [4.0]],
                              dtype=dtypes.float32)
  ids_test = constant_op.constant([1, 3, 2, 3, 0], dtype=dtypes.int64)
  embedding = de.embedding_lookup(embeddings, ids_test)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    self.evaluate(embeddings.upsert(ids, vals))
    self.assertAllClose(embedding.eval(), [[1.0], [3.0], [2.0], [3.0], [0.0]])
    self.assertAllEqual([5, 1], embedding.eval().shape)
    self.assertAllEqual(3, embeddings.size(0).eval())
    self.assertAllEqual(2, embeddings.size(1).eval())
    self.assertAllEqual(0, embeddings.size(2).eval())

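# A quick way to sanity-check the shard-size assertions above is to replay the
# `keys % 2` routing of `_partition_fn` in plain NumPy. This is only an
# illustration of the arithmetic, not part of the test suite:
#
#   import numpy as np
#   keys = np.array([0, 1, 2, 3, 4])
#   shard_of_key = keys % 2
#   shard_sizes = [int(np.sum(shard_of_key == s)) for s in range(3)]
#   print(shard_sizes)  # [3, 2, 0] -- matches size(0), size(1), size(2) above
#
# With three devices but a partitioner that only ever returns 0 or 1, the third
# shard stays empty, which is exactly what size(2) == 0 asserts.
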
def test_check_ops_number(self):
  self.assertTrue(de.get_model_mode() == "train")
  de.enable_inference_mode()
  self.assertTrue(de.get_model_mode() == "inference")
  de.enable_train_mode()
  self.assertTrue(de.get_model_mode() == "train")
  for fn, assign_num, read_num in [(de.enable_train_mode, 1, 2),
                                   (de.enable_inference_mode, 0, 1)]:
    fn()
    embeddings = de.get_variable('ModeModeTest' + str(assign_num),
                                 key_dtype=dtypes.int64,
                                 value_dtype=dtypes.float32,
                                 devices=_get_devices(),
                                 initializer=1.,
                                 dim=8)
    ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
    test_var, trainable = de.embedding_lookup([embeddings],
                                              ids,
                                              return_trainable=True)
    _ = math_ops.add(test_var, 1)
    op_list = ops.get_default_graph().get_operations()
    op_list_assign = [
        op.name for op in op_list if "AssignBeforeReadVariable" in op.name
    ]
    op_list_read = [op.name for op in op_list if "ReadVariableOp" in op.name]
    self.assertTrue(len(op_list_assign) == assign_num)
    self.assertTrue(len(op_list_read) == read_num)
    de.enable_train_mode()
    ops.reset_default_graph()

def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for stateful optimizers."""
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          init_size=params_var_.init_size,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )
    scope_store._vars[full_name] = slot_variable_

  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable

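# For context, a hedged usage sketch of the helper above: an optimizer that has
# just received a TrainableWrapper from de.embedding_lookup(...,
# return_trainable=True) can ask create_slots for a matching slot table. The
# variable name, ids, and optimizer name below are illustrative assumptions,
# not taken from the original code.

# Hypothetical usage (identifiers other than create_slots/de are assumed).
embeddings = de.get_variable("demo_params", dtypes.int64, dtypes.float32, dim=4)
ids = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
_, trainable = de.embedding_lookup(embeddings, ids, return_trainable=True)

# One slot per piece of optimizer state, e.g. an Adagrad-style accumulator
# initialized to 0.1; the slot shares dtype/dim/devices with the primary table.
accum_slot = create_slots(trainable, 0.1, "accumulator", "Adagrad")
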
def test_static_shape_checking(self):
  np.random.seed(8)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    for dim in [1, 10]:
      for ids_shape in [[3, 2], [4, 3], [4, 3, 10]]:
        with variable_scope.variable_scope(
            "test_static_shape_checking" + str(dim),
            reuse=variable_scope.AUTO_REUSE,
        ):
          params = de.get_variable(
              "test_static_shape_checking-" + str(dim),
              dtypes.int64,
              dtypes.float32,
              initializer=2.0,
              dim=dim,
          )
          params_nn = variable_scope.get_variable("n",
                                                  shape=[100, dim],
                                                  use_resource=False)
          ids = np.random.randint(2**31,
                                  size=np.prod(ids_shape),
                                  dtype=np.int64).reshape(ids_shape)
          ids = constant_op.constant(ids, dtype=dtypes.int64)
          embedding_test = de.embedding_lookup(params, ids)
          embedding_base = embedding_ops.embedding_lookup(params_nn, ids)
          self.assertAllEqual(embedding_test.shape, embedding_base.shape)

def commonly_apply_restriction_verify(self, optimizer):
  first_inputs = np.array(range(6), dtype=np.int64)
  second_inputs = np.array(range(4, 9), dtype=np.int64)
  overdue_features = np.array(range(4), dtype=np.int64)
  updated_features = np.array(range(4, 9), dtype=np.int64)
  all_input_features = np.array(range(9), dtype=np.int64)
  embedding_dim = 2
  oversize_trigger = 100
  optimizer = de.DynamicEmbeddingOptimizer(optimizer)

  with session.Session(config=default_config) as sess:
    ids = array_ops.placeholder(dtypes.int64)
    var = de.get_variable('sp_var',
                          key_dtype=ids.dtype,
                          value_dtype=dtypes.float32,
                          initializer=-0.1,
                          dim=embedding_dim,
                          restrict_policy=de.TimestampRestrictPolicy)
    embed_w, trainable = de.embedding_lookup(var,
                                             ids,
                                             return_trainable=True,
                                             name='ut8900')
    loss = _simple_loss(embed_w)
    train_op = optimizer.minimize(loss, var_list=[trainable])
    slot_params = [
        optimizer.get_slot(trainable, name).params
        for name in optimizer.get_slot_names()
    ]
    all_vars = [var] + slot_params + [var.restrict_policy.status]

    sess.run(variables.global_variables_initializer())
    sess.run([train_op], feed_dict={ids: first_inputs})
    time.sleep(1)
    sess.run([train_op], feed_dict={ids: second_inputs})
    for v in all_vars:
      self.assertAllEqual(sess.run(v.size()), 9)

    keys, tstp = sess.run(var.restrict_policy.status.export())
    kvs = sorted(dict(zip(keys, tstp)).items())
    tstp = np.array([x[1] for x in kvs])
    for x in tstp[overdue_features]:
      for y in tstp[updated_features]:
        self.assertLess(x, y)

    sess.run(
        var.restrict_policy.apply_restriction(len(updated_features),
                                              trigger=oversize_trigger))
    for v in all_vars:
      self.assertAllEqual(sess.run(v.size()), len(all_input_features))

    sess.run(
        var.restrict_policy.apply_restriction(len(updated_features),
                                              trigger=len(updated_features)))
    for v in all_vars:
      self.assertAllEqual(sess.run(v.size()), len(updated_features))

    keys, _ = sess.run(var.export())
    keys_sorted = np.sort(keys)
    self.assertAllEqual(keys_sorted, updated_features)

def commonly_apply_update_verify_v2(self):
  if not context.executing_eagerly():
    self.skipTest('Skip graph mode test.')
  first_inputs = np.array(range(6), dtype=np.int64)
  second_inputs = np.array(range(3, 9), dtype=np.int64)
  overdue_features = np.array([0, 1, 2, 6, 7, 8], dtype=np.int64)
  updated_features = np.array(range(3, 6), dtype=np.int64)
  all_features = np.array(range(9), dtype=np.int64)

  with self.session(config=default_config):
    var = de.get_variable('sp_var',
                          key_dtype=dtypes.int64,
                          value_dtype=dtypes.float32,
                          initializer=-0.1,
                          dim=2)
    embed_w, trainable = de.embedding_lookup(var,
                                             first_inputs,
                                             return_trainable=True,
                                             name='vc3939')
    policy = de.FrequencyRestrictPolicy(var)
    self.assertAllEqual(policy.status.size(), 0)

    policy.apply_update(first_inputs)
    self.assertAllEqual(policy.status.size(), len(first_inputs))
    time.sleep(1)
    policy.apply_update(second_inputs)
    self.assertAllEqual(policy.status.size(), len(all_features))

    keys, freq = policy.status.export()
    kvs = sorted(dict(zip(keys.numpy(), freq.numpy())).items())
    freq = np.array([x[1] for x in kvs])
    for x in freq[overdue_features]:
      for y in freq[updated_features]:
        self.assertLess(x, y)

def commonly_apply_update_verify(self):
  first_inputs = np.array(range(3), dtype=np.int64)
  second_inputs = np.array(range(1, 4), dtype=np.int64)
  overdue_features = np.array([0, 3], dtype=np.int64)
  updated_features = np.array(range(1, 3), dtype=np.int64)

  with session.Session(config=default_config) as sess:
    ids = array_ops.placeholder(dtypes.int64)
    var = de.get_variable('sp_var',
                          key_dtype=ids.dtype,
                          value_dtype=dtypes.float32,
                          initializer=-0.1,
                          dim=2)
    embed_w, trainable = de.embedding_lookup(var,
                                             ids,
                                             return_trainable=True,
                                             name='pl3201')
    policy = de.FrequencyRestrictPolicy(var)
    update_op = policy.apply_update(ids)

    self.assertAllEqual(sess.run(policy.status.size()), 0)
    sess.run(update_op, feed_dict={ids: first_inputs})
    self.assertAllEqual(sess.run(policy.status.size()), 3)
    time.sleep(1)
    sess.run(update_op, feed_dict={ids: second_inputs})
    self.assertAllEqual(sess.run(policy.status.size()), 4)

    keys, freq = sess.run(policy.status.export())
    kvs = sorted(dict(zip(keys, freq)).items())
    freq = np.array([x[1] for x in kvs])
    for x in freq[overdue_features]:
      for y in freq[updated_features]:
        self.assertLess(x, y)

def loss_fn(var, features, trainables):
  embed_w, trainable = de.embedding_lookup(var,
                                           features,
                                           return_trainable=True)
  trainables.clear()
  trainables.append(trainable)
  return _simple_loss(embed_w)

def loss_fn(x, trainables):
  ids = constant_op.constant(raw_ids, dtype=k_dtype)
  pred, trainable = de.embedding_lookup([x], ids, return_trainable=True)
  trainables.clear()
  trainables.append(trainable)
  return pred * pred

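# A minimal graph-mode sketch of how such a closure is typically consumed,
# assuming the enclosing test already defines `raw_ids`, `k_dtype`, a
# de.Variable called `embeddings`, and an optimizer `test_opt` (all names here
# are assumptions for illustration, mirroring the minimize-with-var_list
# pattern used elsewhere in this file): the closure is evaluated once to
# create the TrainableWrapper, and the captured wrapper is then optimized.

# Hypothetical wiring (names assumed).
trainables = []
loss = loss_fn(embeddings, trainables)          # populates `trainables`
train_op = test_opt.minimize(loss, var_list=trainables)
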
def test_simple_sharded(self):
  embeddings = de.get_variable(
      "t300",
      dtypes.int64,
      dtypes.float32,
      dim=1,
      devices=_get_devices() * 2,
      initializer=2.0,
  )
  ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
  embedding, trainable = de.embedding_lookup(embeddings,
                                             ids,
                                             max_norm=1.0,
                                             return_trainable=True)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    self.assertAllClose(
        embedding.eval(),
        [
            [1.0],
        ] * 5,
    )
    self.evaluate(trainable.update_op())
    self.assertAllEqual(embeddings.size().eval(), 5)
    self.assertAllEqual(embeddings.size(0).eval(), 3)
    self.assertAllEqual(embeddings.size(1).eval(), 2)

def test_max_norm_nontrivial(self):
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    embeddings = de.get_variable("t320",
                                 dtypes.int64,
                                 dtypes.float32,
                                 initializer=2.0,
                                 dim=2)
    fake_values = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])
    ids = constant_op.constant([0, 1], dtype=dtypes.int64)
    self.evaluate(embeddings.upsert(ids, fake_values))
    embedding_no_norm = de.embedding_lookup(embeddings, ids)
    embedding = de.embedding_lookup(embeddings, ids, max_norm=2.0)
    norms = math_ops.sqrt(
        math_ops.reduce_sum(embedding_no_norm * embedding_no_norm, axis=1))
    normalized = embedding_no_norm / array_ops.stack([norms, norms], axis=1)
    self.assertAllEqual(embedding.eval(), 2 * self.evaluate(normalized))

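# The max_norm arithmetic asserted above can be replayed outside TensorFlow;
# the NumPy sketch below is illustrative only. Rows whose L2 norm exceeds
# `max_norm` are rescaled to exactly that norm, which is why the expected
# value in the test is `2 * normalized`.
#
#   import numpy as np
#   vals = np.array([[2.0, 4.0], [3.0, 1.0]])
#   max_norm = 2.0
#   norms = np.sqrt((vals * vals).sum(axis=1, keepdims=True))  # ~[[4.47], [3.16]]
#   clipped = np.where(norms > max_norm, vals * (max_norm / norms), vals)
#   # Both rows exceed max_norm, so `clipped` equals 2 * (vals / norms).
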
def test_max_norm(self):
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    embeddings = de.get_variable("t310",
                                 dtypes.int64,
                                 dtypes.float32,
                                 initializer=2.0)
    ids = constant_op.constant([0], dtype=dtypes.int64)
    embedding = de.embedding_lookup(embeddings, ids, max_norm=1.0)
    self.assertAllEqual(embedding.eval(), [[1.0]])

def test_dynamic_shape_checking(self):
  np.random.seed(8)
  with self.session(use_gpu=test_util.is_gpu_available(),
                    config=default_config):
    for dim in [1, 10]:
      for ids_shape in [None, [-1, 1], [1, -1, 1], [-1, 1, 1]]:
        with variable_scope.variable_scope(
            "test_static_shape_checking" + str(dim),
            reuse=variable_scope.AUTO_REUSE,
        ):
          params = de.get_variable(
              "test_static_shape_checking-" + str(dim),
              dtypes.int64,
              dtypes.float32,
              initializer=2.0,
              dim=dim,
          )
          params_nn = variable_scope.get_variable("n",
                                                  shape=[100, dim],
                                                  use_resource=False)
          ids = script_ops.py_func(
              _create_dynamic_shape_tensor(min_val=0, max_val=100),
              inp=[],
              Tout=dtypes.int64,
              stateful=True,
          )
          if ids_shape is not None:
            ids = array_ops.reshape(ids, ids_shape)
          embedding_test = de.embedding_lookup(params, ids)
          embedding_base = embedding_ops.embedding_lookup(params_nn, ids)

          # check static shape
          if ids_shape is None:
            # ids with a fully unknown shape
            self.assertTrue(embedding_test.shape == embedding_base.shape)
          else:
            # ids with a partially defined shape
            self.assertAllEqual(
                embedding_test.shape.as_list(),
                embedding_base.shape.as_list(),
            )

          self.evaluate(variables.global_variables_initializer())

          # check dynamic (runtime) shape
          for _ in range(10):
            embedding_test_val, embedding_base_val = self.evaluate(
                [embedding_test, embedding_base])
            self.assertAllEqual(embedding_test_val.shape,
                                embedding_base_val.shape)

def create_slots(primary, init, slot_name, op_name, bp_v2):
  """Helper function for creating a slot variable for stateful optimizers."""
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          kv_creator=params_var_.kv_creator,
          trainable=False,
          checkpoint=params_var_.checkpoint,
          bp_v2=bp_v2 if bp_v2 is not None else params_var_.bp_v2,
      )
    scope_store._vars[full_name] = slot_variable_
    # Record the optimizer Variable into trace.
    primary._optimizer_vars.append(slot_variable_)

  slot_trainable = None
  if context.executing_eagerly():
    slot_tw_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
  else:
    # In graph mode, earlier versions used only `slot_name` to name the
    # trainable wrappers of slots, so keep that name here for forward
    # compatibility.
    slot_tw_name = slot_name
  if isinstance(primary, de.shadow_ops.ShadowVariable):
    slot_trainable = de.shadow_ops.ShadowVariable(
        params=scope_store._vars[full_name],
        ids=primary.ids,
        exists=primary.exists,
        name=full_name,
        trainable=False,
    )
  else:
    _, slot_trainable = de.embedding_lookup(
        params=scope_store._vars[full_name],
        ids=params_ids_,
        name=slot_tw_name,
        return_trainable=True,
    )

  return slot_trainable

def test_treated_as_worker_op_by_device_setter(self):
  num_ps_tasks = 2
  with ops.device("/job:worker/task:0"):
    ids = constant_op.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                               dtype=dtypes.int64)
  setter = device_setter.replica_device_setter(ps_tasks=num_ps_tasks,
                                               ps_device="/job:ps",
                                               worker_device="/job:worker")
  with ops.device(setter):
    p1 = de.get_variable(name="p1",
                         devices=["/job:ps/task:0", "/job:ps/task:1"])
    _ = de.embedding_lookup(p1, ids, name="emb")
  self.assertTrue("/job:ps/task:0" in p1._tables[0].resource_handle.device)
  self.assertTrue("/job:ps/task:1" in p1._tables[1].resource_handle.device)

def test_inference_numberic_correctness(self):
  train_pred = None
  infer_pred = None
  dim = 8
  initializer = init_ops.random_normal_initializer(0.0, 0.001)
  raw_init_vals = np.random.rand(100, dim)

  for fn in [de.enable_train_mode, de.enable_inference_mode]:
    with ops.Graph().as_default():
      fn()
      init_ids = constant_op.constant(list(range(100)), dtype=dtypes.int64)
      init_vals = constant_op.constant(raw_init_vals, dtype=dtypes.float32)
      with variable_scope.variable_scope("modelmode",
                                         reuse=variable_scope.AUTO_REUSE):
        embeddings = de.get_variable('ModelModeTest-numberic',
                                     key_dtype=dtypes.int64,
                                     value_dtype=dtypes.float32,
                                     devices=_get_devices() * 2,
                                     initializer=initializer,
                                     dim=dim)
        w = variables.Variable(1.0, name="w")
        _ = training_util.create_global_step()
      init_op = embeddings.upsert(init_ids, init_vals)

      ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
      test_var, trainable = de.embedding_lookup([embeddings],
                                                ids,
                                                return_trainable=True)
      pred = math_ops.add(test_var, 1) * w
      loss = pred * pred
      opt = de.DynamicEmbeddingOptimizer(adagrad.AdagradOptimizer(0.1))
      opt.minimize(loss)

      with monitored_session.MonitoredTrainingSession(
          is_chief=True, config=default_config) as sess:
        if de.get_model_mode() == de.ModelMode.TRAIN:
          sess.run(init_op)
          train_pred = sess.run(pred)
        elif de.get_model_mode() == de.ModelMode.INFERENCE:
          sess.run(init_op)
          infer_pred = sess.run(pred)
      de.enable_train_mode()
      ops.reset_default_graph()

  self.assertAllEqual(train_pred, infer_pred)

def _test_warm_start_rename(self, num_shards, use_regex):
  devices = ["/cpu:0" for _ in range(num_shards)]
  ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  id_list = [x for x in range(100)]
  val_list = [[x] for x in range(100)]
  emb_name = "t200_{}_{}".format(num_shards, use_regex)

  with self.session(graph=ops.Graph()) as sess:
    embeddings = de.get_variable("save_{}".format(emb_name),
                                 dtypes.int64,
                                 dtypes.float32,
                                 devices=devices,
                                 initializer=0.0)
    ids = constant_op.constant(id_list, dtype=dtypes.int64)
    vals = constant_op.constant(val_list, dtype=dtypes.float32)
    self.evaluate(embeddings.upsert(ids, vals))
    save = saver.Saver(var_list=[embeddings])
    save.save(sess, ckpt_prefix)

  with self.session(graph=ops.Graph()) as sess:
    embeddings = de.get_variable("restore_{}".format(emb_name),
                                 dtypes.int64,
                                 dtypes.float32,
                                 devices=devices,
                                 initializer=0.0)
    ids = constant_op.constant(id_list, dtype=dtypes.int64)
    emb = de.embedding_lookup(embeddings, ids, name="lookup")
    sess.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES,
                                 embeddings)
    vars_to_warm_start = [embeddings]
    if use_regex:
      vars_to_warm_start = [".*t200.*"]
    restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix,
                               vars_to_warm_start=vars_to_warm_start,
                               var_name_to_prev_var_name={
                                   "restore_{}".format(emb_name):
                                       "save_{}".format(emb_name)
                               })
    self.evaluate(restore_op)
    self.assertAllEqual(emb, val_list)

def _model_fn(features, labels, mode, params):
  ids = features['ids']
  embeddings = de.get_variable(emb_name,
                               dtypes.int64,
                               dtypes.float32,
                               devices=devices,
                               initializer=0.0)
  emb = de.embedding_lookup(embeddings, ids, name="lookup")
  emb.graph.add_to_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES,
                              embeddings)
  vars_to_warm_start = [embeddings]
  if use_regex:
    vars_to_warm_start = [".*t300.*"]
  warm_start_hook = de.WarmStartHook(ckpt_to_initialize_from=ckpt_prefix,
                                     vars_to_warm_start=vars_to_warm_start)
  return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT,
                                    predictions=emb,
                                    prediction_hooks=[warm_start_hook])

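# A hedged sketch of how a model_fn like the one above could be driven through
# the Estimator API. It assumes the enclosing test already defines `emb_name`,
# `devices`, `ckpt_prefix`, and `use_regex`, and the input pipeline names here
# are illustrative assumptions, not part of the original test.

# Hypothetical driver (names assumed).
def _input_fn():
  ids = tf.constant(list(range(10)), dtype=tf.int64)
  return tf.data.Dataset.from_tensors({'ids': ids})

estimator = tf.estimator.Estimator(model_fn=_model_fn)
# The WarmStartHook restores the dynamic embedding table before prediction runs.
predictions = list(estimator.predict(input_fn=_input_fn))
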
def test_embedding_lookup_sparse_with_initializer(self):
  id = 0
  embed_dim = 8
  elements_num = 262144
  for initializer, target_mean, target_stddev in [
      (init_ops.random_normal_initializer(0.0, 0.001), 0.0, 0.001),
      (init_ops.truncated_normal_initializer(0.0, 0.001), 0.0, 0.00088),
      (keras_init_ops.RandomNormalV2(mean=0.0, stddev=0.001), 0.0, 0.001),
  ]:
    with self.session(config=default_config,
                      use_gpu=test_util.is_gpu_available()):
      id += 1
      embedding_weights = de.get_variable(
          "emb-init-bugfix-" + str(id),
          key_dtype=dtypes.int64,
          value_dtype=dtypes.float32,
          devices=_get_devices() * 3,
          initializer=initializer,
          dim=embed_dim,
      )

      ids = np.random.randint(
          -0x7FFFFFFFFFFFFFFF,
          0x7FFFFFFFFFFFFFFF,
          elements_num,
          dtype=np.int64,
      )
      ids = np.unique(ids)
      ids = constant_op.constant(ids, dtypes.int64)
      vals_op = de.embedding_lookup(embedding_weights, ids, None).eval()

      mean = self.evaluate(math_ops.reduce_mean(vals_op))
      stddev = self.evaluate(math_ops.reduce_std(vals_op))
      rtol = 2e-5
      atol = rtol
      self.assertTrue(not (list(vals_op[0]) == list(vals_op[1])))
      self.assertAllClose(target_mean, mean, rtol, atol)
      self.assertAllClose(target_stddev, stddev, rtol, atol)

def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for stateful optimizers."""
  # `primary` is a de.TrainableWrapper, so `params_var_` is the underlying
  # de.Variable and `params_ids_` holds the feature ids it was looked up with.
  params_var_, params_ids_ = primary.params, primary.ids
  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      # The slot variable of a de.Variable is itself a de.Variable, i.e.
      # another hash table.
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,  # same dimension as the primary variable
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )
    scope_store._vars[full_name] = slot_variable_

  # The dynamic hash table works by wrapping the lookup result into a
  # variable; that wrapper carries both the hash table and the ids, and the
  # optimizer uses the same information to build a matching slot variable.
  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable

def common_minimize_trainable(self, base_opt, test_opt, name):
  base_opt = de.DynamicEmbeddingOptimizer(base_opt)
  test_opt = de.DynamicEmbeddingOptimizer(test_opt)
  id = 0
  for (
      num_shards,
      k_dtype,
      d_dtype,
      initial_mode,
      dim,
      run_step,
  ) in itertools.product(
      [1, 2],
      [
          dtypes.int64,
      ],
      [
          dtypes.float32,
      ],
      [
          "constant",
      ],
      [1, 10],
      [10],
  ):
    id += 1
    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config) as sess:
      # common definitions
      raw_init_ids = [0, 1]
      raw_init_vals = np.random.rand(2, dim)
      raw_ids = [
          0,
      ]
      x = constant_op.constant(np.random.rand(dim, len(raw_ids)),
                               dtype=d_dtype)

      # base graph
      base_var = resource_variable_ops.ResourceVariable(raw_init_vals,
                                                        dtype=d_dtype)
      ids = constant_op.constant(raw_ids, dtype=k_dtype)
      pred0 = math_ops.matmul(
          embedding_ops.embedding_lookup([base_var], ids), x)
      loss0 = pred0 * pred0
      base_opt_op = base_opt.minimize(loss0)

      # test graph
      embeddings = de.get_variable(
          "t2020-" + name + str(id),
          key_dtype=k_dtype,
          value_dtype=d_dtype,
          devices=_get_devices() * num_shards,
          initializer=1.0,
          dim=dim,
      )
      self.device_check(embeddings)
      init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
      init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
      init_op = embeddings.upsert(init_ids, init_vals)
      self.evaluate(init_op)

      test_var, trainable = de.embedding_lookup([embeddings],
                                                ids,
                                                return_trainable=True)
      pred1 = math_ops.matmul(test_var, x)
      loss1 = pred1 * pred1
      test_opt_op = test_opt.minimize(loss1, var_list=[trainable])

      self.evaluate(variables.global_variables_initializer())
      for _ in range(run_step):
        sess.run(base_opt_op)

      # Fetch params to validate initial values
      self.assertAllCloseAccordingToType(raw_init_vals[raw_ids],
                                         self.evaluate(test_var))
      # Run `run_step` steps of sgd
      for _ in range(run_step):
        sess.run(test_opt_op)

      table_var = embeddings.lookup(ids)
      # Validate updated params
      self.assertAllCloseAccordingToType(
          self.evaluate(base_var)[raw_ids],
          self.evaluate(table_var),
          msg="Cond:{},{},{},{},{},{}".format(num_shards, k_dtype, d_dtype,
                                              initial_mode, dim, run_step),
      )

def test_traing_save_restore(self):
  opt = de.DynamicEmbeddingOptimizer(adam.AdamOptimizer(0.3))
  id = 0
  if test_util.is_gpu_available():
    dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 256, 500]
  else:
    dim_list = [10]
  for key_dtype, value_dtype, dim, step in itertools.product(
      [dtypes.int64],
      [dtypes.float32],
      dim_list,
      [10],
  ):
    id += 1
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    ids = script_ops.py_func(_create_dynamic_shape_tensor(),
                             inp=[],
                             Tout=key_dtype,
                             stateful=True)

    params = de.get_variable(
        name="params-test-0915-" + str(id),
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        initializer=init_ops.random_normal_initializer(0.0, 0.01),
        dim=dim,
    )
    _, var0 = de.embedding_lookup(params, ids, return_trainable=True)

    def loss():
      return var0 * var0

    params_keys, params_vals = params.export()
    mini = opt.minimize(loss, var_list=[var0])
    opt_slots = [opt.get_slot(var0, _s) for _s in opt.get_slot_names()]
    _saver = saver.Saver([params] + [_s.params for _s in opt_slots])

    with self.session(config=default_config,
                      use_gpu=test_util.is_gpu_available()) as sess:
      self.evaluate(variables.global_variables_initializer())
      for _i in range(step):
        self.evaluate([mini])
      size_before_saved = self.evaluate(params.size())
      np_params_keys_before_saved = self.evaluate(params_keys)
      np_params_vals_before_saved = self.evaluate(params_vals)
      opt_slots_kv_pairs = [_s.params.export() for _s in opt_slots]
      np_slots_kv_pairs_before_saved = [
          self.evaluate(_kv) for _kv in opt_slots_kv_pairs
      ]
      _saver.save(sess, save_path)

    with self.session(config=default_config,
                      use_gpu=test_util.is_gpu_available()) as sess:
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(0, self.evaluate(params.size()))

      _saver.restore(sess, save_path)
      params_keys_restored, params_vals_restored = params.export()
      size_after_restored = self.evaluate(params.size())
      np_params_keys_after_restored = self.evaluate(params_keys_restored)
      np_params_vals_after_restored = self.evaluate(params_vals_restored)

      opt_slots_kv_pairs_restored = [_s.params.export() for _s in opt_slots]
      np_slots_kv_pairs_after_restored = [
          self.evaluate(_kv) for _kv in opt_slots_kv_pairs_restored
      ]
      self.assertAllEqual(size_before_saved, size_after_restored)
      self.assertAllEqual(
          np.sort(np_params_keys_before_saved),
          np.sort(np_params_keys_after_restored),
      )
      self.assertAllEqual(
          np.sort(np_params_vals_before_saved, axis=0),
          np.sort(np_params_vals_after_restored, axis=0),
      )
      for pairs_before, pairs_after in zip(np_slots_kv_pairs_before_saved,
                                           np_slots_kv_pairs_after_restored):
        self.assertAllEqual(
            np.sort(pairs_before[0], axis=0),
            np.sort(pairs_after[0], axis=0),
        )
        self.assertAllEqual(
            np.sort(pairs_before[1], axis=0),
            np.sort(pairs_after[1], axis=0),
        )
    if test_util.is_gpu_available():
      self.assertTrue("GPU" in params.tables[0].resource_handle.device)

def model_fn(features, labels, mode, params):
  #logging.info('mode: %s, labels: %s, params: %s, features: %s', mode, labels, params, features)
  if params["args"].get("addon_embedding"):
    import tensorflow_recommenders_addons as tfra
    import tensorflow_recommenders_addons.dynamic_embedding as dynamic_embedding
  else:
    import tensorflow.dynamic_embedding as dynamic_embedding
  features.update(labels)

  logging.info("------ build hyper parameters -------")
  embedding_size = params["parameters"]["embedding_size"]
  learning_rate = params["parameters"]["learning_rate"]
  use_bn = params["parameters"]["use_bn"]
  feat = params['features']
  sparse_feat_list = list(
      set(feat["sparse"]) - set(SPARSE_MASK)) if 'sparse' in feat else []
  sparse_seq_feat_list = list(
      set(feat["sparse_seq"]) -
      set(SPARSE_SEQ_MASK)) if 'sparse_seq' in feat else []
  sparse_seq_feat_list = []
  #dense_feat_list = list(set(feat["dense"]) - set(DENSE_MASK)) if 'dense' in feat else []
  # With hashtable v1/v2 images, BN cannot be disabled and dense_feat masked
  # at the same time.
  dense_feat_list = []
  dense_seq_feat_list = list(
      set(feat["dense_seq"]) -
      set(DENSE_SEQ_MASK)) if 'dense_seq' in feat else []
  sparse_feat_num = len(sparse_feat_list)
  sparse_seq_num = len(sparse_seq_feat_list)
  dense_feat_num = len(dense_feat_list)
  dense_seq_feat_num = len(dense_seq_feat_list)
  all_features = (sparse_feat_list + sparse_seq_feat_list + dense_feat_list +
                  dense_seq_feat_list)
  batch_size = tf.shape(features[goods_id_feat])[0]
  logging.info("------ show batch_size: {} -------".format(batch_size))

  level_0_feats = list(
      set(params.get('level_0_feat_list')) & set(all_features))
  logging.info('level_0_feats: {}'.format(level_0_feats))
  new_features = dict()
  if params["args"].get("level_flag") and params["args"].get(
      "job_type") == "export":
    for feature_name in features:
      if feature_name in level_0_feats:
        new_features[feature_name] = tf.reshape(
            tf.tile(tf.reshape(features[feature_name], [1, -1]),
                    [batch_size, 1]), [batch_size, -1])
      else:
        new_features[feature_name] = features[feature_name]
    features = new_features

  l2_reg = params["parameters"]["l2_reg"]
  is_training = True if mode == tf.estimator.ModeKeys.TRAIN else False
  has_label = True if 'is_imp' in features else False
  logging.info("is_training: {}, has_label: {}, features: {}".format(
      is_training, has_label, features))

  logging.info("------ build embedding -------")
  # def partition_fn(keys, shard_num=params["parameters"]["ps_nums"]):
  #   return tf.cast(keys % shard_num, dtype=tf.int32)
  if is_training:
    devices_info = [
        "/job:ps/replica:0/task:{}/CPU:0".format(i)
        for i in range(params["parameters"]["ps_num"])
    ]
    initializer = tf.compat.v1.truncated_normal_initializer(0.0, 1e-2)
  else:
    devices_info = [
        "/job:localhost/replica:0/task:{}/CPU:0".format(0)
        for i in range(params["parameters"]["ps_num"])
    ]
    initializer = tf.compat.v1.zeros_initializer()
  logging.info(
      "------ dynamic_embedding devices_info is {}-------".format(devices_info))
  if mode == tf.estimator.ModeKeys.PREDICT:
    dynamic_embedding.enable_inference_mode()
  deep_dynamic_variables = dynamic_embedding.get_variable(
      name="deep_dynamic_embeddings",
      devices=devices_info,
      initializer=initializer,
      # partitioner=partition_fn,
      dim=embedding_size,
      trainable=is_training,
      #init_size=INIT_SIZE
  )

  sparse_feat = None
  sparse_unique_ids = None
  if sparse_feat_num > 0:
    logging.info("------ build sparse feature -------")
    id_list = sorted(sparse_feat_list)
    ft_sparse_idx = tf.concat(
        [tf.reshape(features[str(i)], [-1, 1]) for i in id_list], axis=1)
    sparse_unique_ids, sparse_unique_idx = tf.unique(
        tf.reshape(ft_sparse_idx, [-1]))
    sparse_weights = dynamic_embedding.embedding_lookup(
        params=deep_dynamic_variables,
        ids=sparse_unique_ids,
        name="deep_sparse_weights")
    if params["args"].get("zero_padding"):
      sparse_weights = tf.reshape(sparse_weights, [-1, embedding_size])
      sparse_weights = tf.where(
          tf.not_equal(
              tf.expand_dims(sparse_unique_ids, axis=1),
              tf.zeros_like(tf.expand_dims(sparse_unique_ids, axis=1))),
          sparse_weights, tf.zeros_like(sparse_weights))
    sparse_weights = tf.gather(sparse_weights, sparse_unique_idx)
    sparse_feat = tf.reshape(
        sparse_weights, shape=[batch_size, sparse_feat_num * embedding_size])

  sparse_seq_feat = None
  sparse_seq_unique_ids = None
  if sparse_seq_num > 0:
    logging.info("---- build sparse seq feature ---")
    if params["args"].get("merge_sparse_seq"):
      sparse_seq_name_list = sorted(
          sparse_seq_feat_list)  # [B, s1], [B, s2], ... [B, sn]
      ft_sparse_seq_ids = tf.concat(
          [
              tf.reshape(features[str(i)], [batch_size, -1])
              for i in sparse_seq_name_list
          ],
          axis=1)  # [B, [s1, s2, ...sn]] => [B, per_seq_len*seq_num]
      sparse_seq_unique_ids, sparse_seq_unique_idx = tf.unique(
          tf.reshape(ft_sparse_seq_ids, [-1]))  # [u], [B*per_seq_len*seq_num]
      sparse_seq_weights = dynamic_embedding.embedding_lookup(
          params=deep_dynamic_variables,
          ids=sparse_seq_unique_ids,
          name="deep_sparse_seq_weights")  # [u, e]
      deep_embed_seq = tf.where(
          tf.not_equal(
              tf.expand_dims(sparse_seq_unique_ids, axis=1),
              tf.zeros_like(tf.expand_dims(sparse_seq_unique_ids, axis=1))),
          sparse_seq_weights, tf.zeros_like(sparse_seq_weights))  # [u, e]
      deep_embedding_seq = tf.reshape(
          tf.gather(deep_embed_seq,
                    sparse_seq_unique_idx),  # [B*per_seq_len*seq_num, e]
          shape=[batch_size, sparse_seq_num, -1,
                 embedding_size])  # [B, seq_num, per_seq_len, e]
      if params["parameters"]["combiner"] == "sum":
        tmp_feat = tf.reduce_sum(deep_embedding_seq, axis=2)
      else:
        tmp_feat = tf.reduce_mean(deep_embedding_seq, axis=2)
      sparse_seq_feat = tf.reshape(
          tmp_feat, [batch_size, sparse_seq_num * embedding_size])  # [B, seq_num*e]
    else:
      sparse_seq_feats = []
      sparse_ids = []
      for sparse_seq_name in sparse_seq_feat_list:
        sp_ids = features[sparse_seq_name]
        if params["args"].get("zero_padding2"):
          sparse_seq_unique_ids, sparse_seq_unique_idx, _ = tf.unique_with_counts(
              tf.reshape(sp_ids, [-1]))
          deep_sparse_seq_weights = tf.reshape(
              dynamic_embedding.embedding_lookup(
                  params=deep_dynamic_variables,
                  ids=sparse_seq_unique_ids,
                  name="deep_sparse_weights_{}".format(sparse_seq_name)),
              [-1, embedding_size])
          deep_embed_seq = tf.where(
              tf.not_equal(
                  tf.expand_dims(sparse_seq_unique_ids, axis=1),
                  tf.zeros_like(tf.expand_dims(sparse_seq_unique_ids,
                                               axis=1))),
              deep_sparse_seq_weights,
              tf.zeros_like(deep_sparse_seq_weights))
          deep_embedding_seq = tf.reshape(
              tf.gather(deep_embed_seq, sparse_seq_unique_idx),
              shape=[batch_size, -1, embedding_size])
          if params["parameters"]["combiner"] == "sum":
            tmp_feat = tf.reduce_sum(deep_embedding_seq, axis=1)
          else:
            tmp_feat = tf.reduce_mean(deep_embedding_seq, axis=1)
          sparse_ids.append(sparse_seq_unique_ids)
          sparse_seq_feats.append(
              tf.reshape(tmp_feat, [batch_size, embedding_size]))
        else:
          tmp_feat = dynamic_embedding.safe_embedding_lookup_sparse(
              embedding_weights=deep_dynamic_variables,
              sparse_ids=sp_ids,
              combiner=params["parameters"]["combiner"],
              name="safe_embedding_lookup_sparse")
          temp_uni_id, _, _ = tf.unique_with_counts(
              tf.reshape(sp_ids.values, [-1]))
          sparse_ids.append(temp_uni_id)
          sparse_seq_feats.append(
              tf.reshape(tmp_feat, [batch_size, embedding_size]))
      sparse_seq_feat = tf.concat(sparse_seq_feats, axis=1)
      sparse_seq_unique_ids, _ = tf.unique(tf.concat(sparse_ids, axis=0))

  dense_feat = None
  if dense_feat_num > 0:
    logging.info("------ build dense feature -------")
    den_id_list = sorted(dense_feat_list)
    dense_feat_base = tf.concat(
        [tf.reshape(features[str(i)], [-1, 1]) for i in den_id_list], axis=1)
    #deep_dense_w1 = tf.compat.v1.get_variable('deep_dense_w1',
    #    tf.TensorShape([dense_feat_num]),
    #    initializer=tf.compat.v1.truncated_normal_initializer(
    #        2.0 / math.sqrt(dense_feat_num)),
    #    dtype=tf.float32)
    #deep_dense_w2 = tf.compat.v1.get_variable('deep_dense_w2',
    #    tf.TensorShape([dense_feat_num]),
    #    initializer=tf.compat.v1.truncated_normal_initializer(
    #        2.0 / math.sqrt(dense_feat_num)),
    #    dtype=tf.float32)
    #w1 = tf.tile(tf.expand_dims(deep_dense_w1, axis=0),
    #             [tf.shape(dense_feat_base)[0], 1])
    #dense_input_1 = tf.multiply(dense_feat_base, w1)
    #dense_feat = dense_input_1
    dense_feat = dense_feat_base

  dense_seq_feat = None
  if dense_seq_feat_num > 0:
    logging.info("------ build dense seq feature -------")
    den_seq_id_list = sorted(dense_seq_feat_list)
    dense_seq_feat = tf.concat(
        [tf.reshape(features[str(i[0])], [-1, i[1]]) for i in den_seq_id_list],
        axis=1)

  logging.info("------ join all feature -------")
  fc_inputs = tf.concat([
      x for x in [sparse_feat, sparse_seq_feat, dense_feat, dense_seq_feat]
      if x is not None
  ],
                        axis=1)
  logging.info("---- tracy debug input is ----")
  logging.info(sparse_feat)
  logging.info(sparse_seq_feat)
  logging.info(dense_feat)
  logging.info(dense_seq_feat)
  logging.info(fc_inputs)

  logging.info("------ join fc -------")
  for idx, units in enumerate(params["parameters"]["hidden_units"]):
    fc_inputs = fully_connected_with_bn_ahead(inputs=fc_inputs,
                                              num_outputs=units,
                                              l2_reg=l2_reg,
                                              scope="out_mlp_{}".format(idx),
                                              activation_fn=tf.nn.relu,
                                              train_phase=is_training,
                                              use_bn=use_bn)
  y_deep_ctr = fully_connected_with_bn_ahead(inputs=fc_inputs,
                                             num_outputs=1,
                                             activation_fn=tf.identity,
                                             l2_reg=l2_reg,
                                             scope="ctr_mlp",
                                             train_phase=is_training,
                                             use_bn=use_bn)

  logging.info("------ build ctr out -------")
  sample_rate = params["args"]["sample_rate"]
  logit = tf.reshape(y_deep_ctr, shape=[-1], name="logit")
  sample_logit = get_sample_logits(logit, sample_rate)
  pred_ctr = tf.nn.sigmoid(logit, name="pred_ctr")
  sample_pred_ctr = tf.nn.sigmoid(sample_logit, name="sample_pred_ctr")

  logging.info("------ build predictions -------")
  preds = {
      'p_ctr': tf.reshape(pred_ctr, shape=[-1, 1]),
  }
  logging.info("---- deep_dynamic_variables.size ----")
  logging.info(deep_dynamic_variables.size())
  size = tf.identity(deep_dynamic_variables.size(), name="size")

  label_col = "is_clk"
  if params["args"].get("set_train_labels"):
    label_col = params["args"]["set_train_labels"]["1"]
  logging.info(
      "------ build labels, label_col: {} -------".format(label_col))
  if has_label:
    labels_ctr = tf.reshape(features["is_clk"], shape=[-1], name="labels_ctr")

  if mode == tf.estimator.ModeKeys.PREDICT:
    logging.info("---- build tf-serving predict ----")
    pred_cvr = tf.fill(tf.shape(pred_ctr), 1.0)
    preds.update({
        'labels_cart': tf.reshape(pred_cvr, shape=[-1, 1]),
        'p_car': tf.reshape(features["dense_1608"], shape=[-1, 1]),
        'labels_cvr': tf.reshape(pred_cvr, shape=[-1, 1]),
        'p_cvr': tf.reshape(pred_cvr, shape=[-1, 1]),
    })
    if 'logid' in features:
      preds.update({'logid': tf.reshape(features["logid"], shape=[-1, 1])})
    if has_label:
      logging.info("------ build offline label -------")
      preds["labels_ctr"] = tf.reshape(labels_ctr, shape=[-1, 1])
    export_outputs = {
        "predict_export_outputs":
            tf.estimator.export.PredictOutput(outputs=preds)
    }
    return tf.estimator.EstimatorSpec(mode,
                                      predictions=preds,
                                      export_outputs=export_outputs)
  logging.info("----all vars:-----" + str(tf.compat.v1.global_variables()))
  for var in tf.compat.v1.trainable_variables():
    logging.info("----trainable------" + str(var))

  logging.info("------ build metric -------")
  #loss = tf.reduce_mean(
  #    tf.compat.v1.losses.log_loss(labels=labels_ctr,
  #                                 predictions=sample_pred_ctr),
  #    name="loss")
  loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels_ctr, logits=sample_logit),
                        name="loss")
  ctr_auc = tf.compat.v1.metrics.auc(labels=labels_ctr,
                                     predictions=sample_pred_ctr,
                                     name="ctr_auc")
  label_ctr_avg = tf.reduce_mean(labels_ctr, name="label_ctr_avg")
  real_pred_ctr_avg = tf.reduce_mean(pred_ctr, name="real_pred_ctr_avg")
  sample_pred_ctr_avg = tf.reduce_mean(sample_pred_ctr, name="pred_ctr_avg")
  sample_pred_bias_avg = tf.add(sample_pred_ctr_avg,
                                tf.negative(label_ctr_avg),
                                name="pred_bias_avg")
  tf.compat.v1.summary.histogram('labels_ctr', labels_ctr)
  tf.compat.v1.summary.histogram('pred_ctr', sample_pred_ctr)
  tf.compat.v1.summary.histogram('real_pred_ctr', pred_ctr)
  tf.compat.v1.summary.scalar('label_ctr_avg', label_ctr_avg)
  tf.compat.v1.summary.scalar('pred_ctr_avg', sample_pred_ctr_avg)
  tf.compat.v1.summary.scalar('real_pred_ctr_avg', real_pred_ctr_avg)
  tf.compat.v1.summary.scalar('pred_bias_avg', sample_pred_bias_avg)
  tf.compat.v1.summary.scalar('loss', loss)
  tf.compat.v1.summary.scalar('ctr_auc', ctr_auc[1])

  logging.info("------ compute l2 reg -------")
  if params["parameters"]["use_l2"]:
    all_unique_ids, _ = tf.unique(
        tf.concat([
            x for x in [sparse_unique_ids, sparse_seq_unique_ids]
            if x is not None
        ],
                  axis=0))
    all_unique_ids_w = dynamic_embedding.embedding_lookup(
        deep_dynamic_variables,
        all_unique_ids,
        name="unique_ids_weights",
        return_trainable=False)
    embed_loss = l2_reg * tf.nn.l2_loss(
        tf.reshape(all_unique_ids_w,
                   shape=[-1, embedding_size])) + tf.reduce_sum(
                       tf.compat.v1.get_collection(
                           tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES))
    tf.compat.v1.summary.scalar('embed_loss', embed_loss)
    loss = loss + embed_loss
  loss = tf.identity(loss, name="total_loss")
  tf.compat.v1.summary.scalar('total_loss', loss)

  if mode == tf.estimator.ModeKeys.EVAL:
    logging.info("------ EVAL -------")
    eval_metric_ops = {
        "ctr_auc_eval": ctr_auc,
    }
    if has_label:
      logging.info("------ build offline label -------")
      preds["labels_ctr"] = tf.reshape(labels_ctr, shape=[-1, 1])
    export_outputs = {
        "predict_export_outputs":
            tf.estimator.export.PredictOutput(outputs=preds)
    }
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      eval_metric_ops=eval_metric_ops,
                                      export_outputs=export_outputs)

  logging.info("---- Learning rate ----")
  lr = get_learning_rate(params["parameters"]["learning_rate"],
                         params["parameters"]["use_decay"])
  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.compat.v1.train.get_global_step()
    logging.info("------ TRAIN -------")
    optimizer_type = params["parameters"].get('optimizer', 'Adam')
    if optimizer_type == 'Sgd':
      optimizer = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate=lr)
    elif optimizer_type == 'Adagrad':
      optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=lr)
    elif optimizer_type == 'Rmsprop':
      optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=lr)
    elif optimizer_type == 'Ftrl':
      optimizer = tf.compat.v1.train.FtrlOptimizer(learning_rate=lr)
    elif optimizer_type == 'Momentum':
      optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=lr,
                                                       momentum=0.9)
    else:
      optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr,
                                                   beta1=0.9,
                                                   beta2=0.999,
                                                   epsilon=1e-8)
    if params["args"].get("addon_embedding"):
      optimizer = dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=global_step)
    # fix tf2 batch_normalization bug
    update_ops = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.UPDATE_OPS)
    logging.info('train ops: {}, update ops: {}'.format(
        str(train_op), str(update_ops)))
    train_op = tf.group([train_op, update_ops])
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=preds,
                                      loss=loss,
                                      train_op=train_op)