Example #1
  def test_sharded_multi_lookup_on_one_variable(self):
    embeddings = de.get_variable(
        "t340",
        dtypes.int64,
        dtypes.float32,
        devices=_get_devices() * 3,
        initializer=2.0,
    )

    ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
    vals = constant_op.constant([[0.0], [1.0], [2.0], [3.0], [4.0]],
                                dtype=dtypes.float32)
    new_vals = constant_op.constant([[10.0], [11.0], [12.0], [13.0], [14.0]],
                                    dtype=dtypes.float32)

    ids0 = constant_op.constant([1, 3, 2], dtype=dtypes.int64)
    ids1 = constant_op.constant([3, 4], dtype=dtypes.int64)

    embedding0 = de.embedding_lookup(embeddings, ids0)
    embedding1 = de.embedding_lookup(embeddings, ids1)

    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      self.evaluate(embeddings.upsert(ids, vals))
      self.assertAllClose(embedding0.eval(), [[1.0], [3.0], [2.0]])
      self.assertAllEqual([3, 1], embedding0.eval().shape)
      self.assertAllClose(embedding1.eval(), [[3.0], [4.0]])
      self.assertAllEqual([2, 1], embedding1.eval().shape)
      self.evaluate(embeddings.upsert(ids, new_vals))
      self.assertAllClose(embedding1.eval(), [[13.0], [14.0]])
      self.assertAllEqual([2, 1], embedding1.eval().shape)
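
These tests rely on two helpers that are not shown in the excerpt, `_get_devices()` and `default_config`. A minimal sketch of what they might look like (assumed definitions; the real ones live in the TFRA test utilities):

# Hedged sketch of the assumed test helpers, not the original definitions.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import test_util

def _get_devices():
  # Place the embedding tables on GPU when one is visible, otherwise on CPU.
  return ["/gpu:0" if test_util.is_gpu_available() else "/cpu:0"]

default_config = config_pb2.ConfigProto(
    allow_soft_placement=False,
    gpu_options=config_pb2.GPUOptions(allow_growth=True))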
    def test_scope_reuse_embedding_lookup(self):
        ids = constant_op.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                                   dtype=dtypes.int64)
        with variable_scope.variable_scope("test",
                                           reuse=variable_scope.AUTO_REUSE):
            p1 = de.get_variable(name="p1")
            with variable_scope.variable_scope("q"):
                _, t1 = de.embedding_lookup(p1,
                                            ids,
                                            name="emb",
                                            return_trainable=True)

        with variable_scope.variable_scope("test",
                                           reuse=variable_scope.AUTO_REUSE):
            p1_reuse = de.get_variable(name="p1")
            p2 = de.get_variable(name="p2")
            with variable_scope.variable_scope("q"):
                _, t2 = de.embedding_lookup(p2,
                                            ids,
                                            name="emb",
                                            return_trainable=True)

        self.assertAllEqual(p1.name, "test/p1")
        self.assertAllEqual(p2.name, "test/p2")
        self.assertAllEqual(p1, p1_reuse)
        self.assertEqual(t1.name, "test/q/emb/TrainableWrapper:0")
        self.assertEqual(t2.name, "test/q/emb/TrainableWrapper_1:0")
        self.assertAllEqual(p1._tables[0].name, "test_p1_mht_1of1")
        self.assertAllEqual(p1_reuse._tables[0].name, "test_p1_mht_1of1")
        self.assertAllEqual(p2._tables[0].name, "test_p2_mht_1of1")
Example #3
  def test_embedding_lookup_shape(self):

    def _evaluate(tensors, feed_dict):
      sess = ops.get_default_session()
      if sess is None:
        with self.test_session() as sess:
          return sess.run(tensors, feed_dict=feed_dict)
      else:
        return sess.run(tensors, feed_dict=feed_dict)

    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      default_val = -1

      keys = constant_op.constant([0, 1, 2], dtypes.int64)
      values = constant_op.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2]],
                                    dtypes.int32)
      table = de.get_variable("t140",
                              dtypes.int64,
                              dtypes.int32,
                              dim=3,
                              initializer=default_val)
      self.evaluate(table.upsert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      # shape of ids is fully defined
      ids = constant_op.constant([[0, 1], [2, 4]], dtypes.int64)
      embeddings = de.embedding_lookup(table, ids)
      self.assertAllEqual([2, 2, 3], embeddings.get_shape())
      re = self.evaluate(embeddings)
      self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                          re)

      # shape of ids is partially defined
      ids = gen_array_ops.placeholder(shape=(2, None), dtype=dtypes.int64)
      embeddings = de.embedding_lookup(table, ids)
      self.assertFalse(embeddings.get_shape().is_fully_defined())
      re = _evaluate(
          embeddings,
          feed_dict={ids: np.asarray([[0, 1], [2, 4]], dtype=np.int64)})
      self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                          re)

      # shape of ids is unknown
      ids = gen_array_ops.placeholder(dtype=dtypes.int64)
      embeddings = de.embedding_lookup(table, ids)
      self.assertEqual(embeddings.get_shape(), tensor_shape.unknown_shape())
      re = _evaluate(
          embeddings,
          feed_dict={ids: np.asarray([[0, 1], [2, 4]], dtype=np.int64)})
      self.assertAllEqual([[[0, 0, 0], [1, 1, 1]], [[2, 2, 2], [-1, -1, -1]]],
                          re)
    def test_higher_rank(self):
        np.random.seed(8)
        with self.session(use_gpu=test_util.is_gpu_available(),
                          config=default_config):
            for dim in [1, 10]:
                for ids_shape in [[3, 2], [4, 3], [4, 3, 10]]:
                    with variable_scope.variable_scope("test_higher_rank",
                                                       reuse=True):
                        params = de.get_variable(
                            "t350-" + str(dim),
                            dtypes.int64,
                            dtypes.float32,
                            initializer=2.0,
                            dim=dim,
                        )
                        ids = np.random.randint(
                            2**31, size=np.prod(ids_shape),
                            dtype=np.int64).reshape(ids_shape)
                        ids = constant_op.constant(ids, dtype=dtypes.int64)
                        simple = params.lookup(ids)
                        self.evaluate(params.upsert(ids, simple))

                        embedding = de.embedding_lookup(params, ids)
                        self.assertAllEqual(simple.eval(), embedding.eval())
                        self.assertAllEqual(ids_shape + [dim],
                                            embedding.eval().shape)
Example #5
  def test_sharded_custom_partitioner_int32_ids(self):

    def _partition_fn(keys, shard_num):
      return math_ops.cast(keys % 2, dtype=dtypes.int32)

    embeddings = de.get_variable(
        "t330",
        dtypes.int64,
        dtypes.float32,
        partitioner=_partition_fn,
        devices=_get_devices() * 3,
        initializer=2.0,
    )

    ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
    vals = constant_op.constant([[0.0], [1.0], [2.0], [3.0], [4.0]],
                                dtype=dtypes.float32)
    ids_test = constant_op.constant([1, 3, 2, 3, 0], dtype=dtypes.int64)
    embedding = de.embedding_lookup(embeddings, ids_test)
    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      self.evaluate(embeddings.upsert(ids, vals))
      self.assertAllClose(embedding.eval(), [[1.0], [3.0], [2.0], [3.0], [0.0]])
      self.assertAllEqual([5, 1], embedding.eval().shape)
      self.assertAllEqual(3, embeddings.size(0).eval())
      self.assertAllEqual(2, embeddings.size(1).eval())
      self.assertAllEqual(0, embeddings.size(2).eval())
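
Note: with `_partition_fn = keys % 2`, even keys go to shard 0 and odd keys to shard 1, so after upserting keys 0-4 shard 0 holds {0, 2, 4} (size 3), shard 1 holds {1, 3} (size 2), and the third device's shard stays empty, which is exactly what the three `size(i)` assertions verify.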
 def test_check_ops_number(self):
   self.assertTrue(de.get_model_mode() == "train")
   de.enable_inference_mode()
   self.assertTrue(de.get_model_mode() == "inference")
   de.enable_train_mode()
   self.assertTrue(de.get_model_mode() == "train")
   for fn, assign_num, read_num in [(de.enable_train_mode, 1, 2),
                                    (de.enable_inference_mode, 0, 1)]:
     fn()
     embeddings = de.get_variable('ModeModeTest' + str(assign_num),
                                  key_dtype=dtypes.int64,
                                  value_dtype=dtypes.float32,
                                  devices=_get_devices(),
                                  initializer=1.,
                                  dim=8)
     ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
     test_var, trainable = de.embedding_lookup([embeddings],
                                               ids,
                                               return_trainable=True)
     _ = math_ops.add(test_var, 1)
     op_list = ops.get_default_graph().get_operations()
     op_list_assign = [
         op.name for op in op_list if "AssignBeforeReadVariable" in op.name
     ]
     op_list_read = [op.name for op in op_list if "ReadVariableOp" in op.name]
     self.assertTrue(len(op_list_assign) == assign_num)
     self.assertTrue(len(op_list_read) == read_num)
     de.enable_train_mode()
     ops.reset_default_graph()
def create_slots(primary, init, slot_name, op_name):
    """Helper function for creating a slot variable for statefull optimizers."""
    params_var_, params_ids_ = primary.params, primary.ids

    scope_store = variable_scope._get_default_variable_store()
    full_name = params_var_.name + "/" + op_name + "/" + slot_name
    if full_name not in scope_store._vars:
        with ops.colocate_with(primary, ignore_existing=True):
            slot_variable_ = de.Variable(
                name=full_name,
                key_dtype=params_var_.key_dtype,
                value_dtype=params_var_.value_dtype,
                dim=params_var_.dim,
                devices=params_var_.devices,
                partitioner=params_var_.partition_fn,
                initializer=init,
                init_size=params_var_.init_size,
                trainable=False,
                checkpoint=params_var_.checkpoint,
            )

        scope_store._vars[full_name] = slot_variable_

    slot_trainable = None
    _, slot_trainable = de.embedding_lookup(
        params=scope_store._vars[full_name],
        ids=params_ids_,
        name=slot_name,
        return_trainable=True,
    )

    return slot_trainable
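
For orientation, a hedged sketch of how an optimizer wrapper might call this helper. The function name and the "accumulator"/"Adagrad" strings below are illustrative, not the actual TFRA call site:

# Hypothetical call site: one slot per de.TrainableWrapper returned by
# de.embedding_lookup(..., return_trainable=True).
def _create_adagrad_slots(trainable_wrappers):
  # Keep each slot alongside the trainable wrapper that owns it.
  return [(tw, create_slots(tw, 0.1, "accumulator", "Adagrad"))
          for tw in trainable_wrappers]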
Example #8
  def test_static_shape_checking(self):
    np.random.seed(8)
    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      for dim in [1, 10]:
        for ids_shape in [[3, 2], [4, 3], [4, 3, 10]]:
          with variable_scope.variable_scope(
              "test_static_shape_checking" + str(dim),
              reuse=variable_scope.AUTO_REUSE,
          ):
            params = de.get_variable(
                "test_static_shape_checking-" + str(dim),
                dtypes.int64,
                dtypes.float32,
                initializer=2.0,
                dim=dim,
            )
            params_nn = variable_scope.get_variable("n",
                                                    shape=[100, dim],
                                                    use_resource=False)
            ids = np.random.randint(2**31,
                                    size=np.prod(ids_shape),
                                    dtype=np.int64).reshape(ids_shape)
            ids = constant_op.constant(ids, dtype=dtypes.int64)

            embedding_test = de.embedding_lookup(params, ids)
            embedding_base = embedding_ops.embedding_lookup(params_nn, ids)
            self.assertAllEqual(embedding_test.shape, embedding_base.shape)
    def commonly_apply_restriction_verify(self, optimizer):
        first_inputs = np.array(range(6), dtype=np.int64)
        second_inputs = np.array(range(4, 9), dtype=np.int64)
        overdue_features = np.array(range(4), dtype=np.int64)
        updated_features = np.array(range(4, 9), dtype=np.int64)
        all_input_features = np.array(range(9), dtype=np.int64)
        embedding_dim = 2
        oversize_trigger = 100
        optimizer = de.DynamicEmbeddingOptimizer(optimizer)

        with session.Session(config=default_config) as sess:
            ids = array_ops.placeholder(dtypes.int64)
            var = de.get_variable('sp_var',
                                  key_dtype=ids.dtype,
                                  value_dtype=dtypes.float32,
                                  initializer=-0.1,
                                  dim=embedding_dim,
                                  restrict_policy=de.TimestampRestrictPolicy)
            embed_w, trainable = de.embedding_lookup(var,
                                                     ids,
                                                     return_trainable=True,
                                                     name='ut8900')
            loss = _simple_loss(embed_w)
            train_op = optimizer.minimize(loss, var_list=[trainable])

            slot_params = [
                optimizer.get_slot(trainable, name).params
                for name in optimizer.get_slot_names()
            ]
            all_vars = [var] + slot_params + [var.restrict_policy.status]

            sess.run(variables.global_variables_initializer())

            sess.run([train_op], feed_dict={ids: first_inputs})
            time.sleep(1)
            sess.run([train_op], feed_dict={ids: second_inputs})
            for v in all_vars:
                self.assertAllEqual(sess.run(v.size()), 9)
            keys, tstp = sess.run(var.restrict_policy.status.export())
            kvs = sorted(dict(zip(keys, tstp)).items())
            tstp = np.array([x[1] for x in kvs])
            for x in tstp[overdue_features]:
                for y in tstp[updated_features]:
                    self.assertLess(x, y)

            sess.run(
                var.restrict_policy.apply_restriction(
                    len(updated_features), trigger=oversize_trigger))
            for v in all_vars:
                self.assertAllEqual(sess.run(v.size()),
                                    len(all_input_features))

            sess.run(
                var.restrict_policy.apply_restriction(
                    len(updated_features), trigger=len(updated_features)))
            for v in all_vars:
                self.assertAllEqual(sess.run(v.size()), len(updated_features))
            keys, _ = sess.run(var.export())
            keys_sorted = np.sort(keys)
            self.assertAllEqual(keys_sorted, updated_features)
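
`_simple_loss` is referenced above but not included in the excerpt. A plausible stand-in, assuming it only needs to reduce the looked-up embeddings to a differentiable scalar (the original helper may differ):

from tensorflow.python.ops import math_ops

def _simple_loss(embeddings):
  # Any differentiable reduction of the looked-up embeddings is enough to
  # exercise the optimizer and the restrict policies.
  return math_ops.reduce_mean(embeddings * embeddings)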
    def commonly_apply_update_verify_v2(self):
        if not context.executing_eagerly():
            self.skipTest('Skip graph mode test.')

        first_inputs = np.array(range(6), dtype=np.int64)
        second_inputs = np.array(range(3, 9), dtype=np.int64)
        overdue_features = np.array([0, 1, 2, 6, 7, 8], dtype=np.int64)
        updated_features = np.array(range(3, 6), dtype=np.int64)
        all_features = np.array(range(9), dtype=np.int64)

        with self.session(config=default_config):
            var = de.get_variable('sp_var',
                                  key_dtype=dtypes.int64,
                                  value_dtype=dtypes.float32,
                                  initializer=-0.1,
                                  dim=2)
            embed_w, trainable = de.embedding_lookup(var,
                                                     first_inputs,
                                                     return_trainable=True,
                                                     name='vc3939')
            policy = de.FrequencyRestrictPolicy(var)

            self.assertAllEqual(policy.status.size(), 0)
            policy.apply_update(first_inputs)
            self.assertAllEqual(policy.status.size(), len(first_inputs))
            time.sleep(1)
            policy.apply_update(second_inputs)
            self.assertAllEqual(policy.status.size(), len(all_features))

            keys, freq = policy.status.export()
            kvs = sorted(dict(zip(keys.numpy(), freq.numpy())).items())
            freq = np.array([x[1] for x in kvs])
            for x in freq[overdue_features]:
                for y in freq[updated_features]:
                    self.assertLess(x, y)
    def commonly_apply_update_verify(self):
        first_inputs = np.array(range(3), dtype=np.int64)
        second_inputs = np.array(range(1, 4), dtype=np.int64)
        overdue_features = np.array([0, 3], dtype=np.int64)
        updated_features = np.array(range(1, 3), dtype=np.int64)
        with session.Session(config=default_config) as sess:
            ids = array_ops.placeholder(dtypes.int64)
            var = de.get_variable('sp_var',
                                  key_dtype=ids.dtype,
                                  value_dtype=dtypes.float32,
                                  initializer=-0.1,
                                  dim=2)
            embed_w, trainable = de.embedding_lookup(var,
                                                     ids,
                                                     return_trainable=True,
                                                     name='pl3201')
            policy = de.FrequencyRestrictPolicy(var)
            update_op = policy.apply_update(ids)

            self.assertAllEqual(sess.run(policy.status.size()), 0)
            sess.run(update_op, feed_dict={ids: first_inputs})
            self.assertAllEqual(sess.run(policy.status.size()), 3)
            time.sleep(1)
            sess.run(update_op, feed_dict={ids: second_inputs})
            self.assertAllEqual(sess.run(policy.status.size()), 4)

            keys, freq = sess.run(policy.status.export())
            kvs = sorted(dict(zip(keys, freq)).items())
            freq = np.array([x[1] for x in kvs])
            for x in freq[overdue_features]:
                for y in freq[updated_features]:
                    self.assertLess(x, y)
 def loss_fn(var, features, trainables):
     embed_w, trainable = de.embedding_lookup(var,
                                              features,
                                              return_trainable=True)
     trainables.clear()
     trainables.append(trainable)
     return _simple_loss(embed_w)
 def loss_fn(x, trainables):
     ids = constant_op.constant(raw_ids, dtype=k_dtype)
     pred, trainable = de.embedding_lookup(
         [x], ids, return_trainable=True)
     trainables.clear()
     trainables.append(trainable)
     return pred * pred
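
A minimal sketch of how the `loss_fn` pattern above is typically driven in eager mode, assuming a tf.keras (optimizer_v2-style) optimizer wrapped by `de.DynamicEmbeddingOptimizer`; the callable `var_list` lets the optimizer pick up the `TrainableWrapper` that `loss_fn` creates on each call:

# Hedged usage sketch (eager execution assumed; names are illustrative).
import tensorflow as tf

var = de.get_variable("loss_fn_sketch", dtypes.int64, dtypes.float32, dim=2)
features = constant_op.constant([0, 1, 2], dtype=dtypes.int64)
trainables = []
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(0.01))
# The loss callable repopulates `trainables` every time it runs, so a callable
# var_list always sees the TrainableWrapper created in the current step.
optimizer.minimize(lambda: loss_fn(var, features, trainables),
                   var_list=lambda: trainables)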
  def test_simple_sharded(self):
    embeddings = de.get_variable(
        "t300",
        dtypes.int64,
        dtypes.float32,
        dim=5,
        devices=_get_devices() * 2,
        initializer=2.0,
    )

    ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
    embedding, trainable = de.embedding_lookup(embeddings,
                                               ids,
                                               max_norm=1.0,
                                               return_trainable=True)
    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      self.assertAllClose(
          embedding.eval(),
          [
              [1.0],
          ] * 5,
      )
      self.evaluate(trainable.update_op())
      self.assertAllEqual(embeddings.size().eval(), 5)
      self.assertAllEqual(embeddings.size(0).eval(), 3)
      self.assertAllEqual(embeddings.size(1).eval(), 2)
Example #15
 def test_max_norm_nontrivial(self):
   with self.session(use_gpu=test_util.is_gpu_available(),
                     config=default_config):
     embeddings = de.get_variable("t320",
                                  dtypes.int64,
                                  dtypes.float32,
                                  initializer=2.0,
                                  dim=2)
     fake_values = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])
     ids = constant_op.constant([0, 1], dtype=dtypes.int64)
     self.evaluate(embeddings.upsert(ids, fake_values))
     embedding_no_norm = de.embedding_lookup(embeddings, ids)
     embedding = de.embedding_lookup(embeddings, ids, max_norm=2.0)
     norms = math_ops.sqrt(
         math_ops.reduce_sum(embedding_no_norm * embedding_no_norm, axis=1))
     normalized = embedding_no_norm / array_ops.stack([norms, norms], axis=1)
     self.assertAllEqual(embedding.eval(), 2 * self.evaluate(normalized))
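
For reference on the final assertion: both stored rows exceed the max_norm of 2.0 (‖[2.0, 4.0]‖ ≈ 4.47 and ‖[3.0, 1.0]‖ ≈ 3.16), so each is rescaled to L2 norm 2, i.e. 2 * (row / ‖row‖), which is why the clipped lookup equals `2 * normalized`.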
Example #16
  def test_max_norm(self):
    with self.session(use_gpu=test_util.is_gpu_available(),
                      config=default_config):
      embeddings = de.get_variable("t310",
                                   dtypes.int64,
                                   dtypes.float32,
                                   initializer=2.0)

      ids = constant_op.constant([0], dtype=dtypes.int64)
      embedding = de.embedding_lookup(embeddings, ids, max_norm=1.0)
      self.assertAllEqual(embedding.eval(), [[1.0]])
    def test_dynamic_shape_checking(self):
        np.random.seed(8)
        with self.session(use_gpu=test_util.is_gpu_available(),
                          config=default_config):
            for dim in [1, 10]:
                for ids_shape in [None, [-1, 1], [1, -1, 1], [-1, 1, 1]]:
                    with variable_scope.variable_scope(
                            "test_static_shape_checking" + str(dim),
                            reuse=variable_scope.AUTO_REUSE,
                    ):
                        params = de.get_variable(
                            "test_static_shape_checking-" + str(dim),
                            dtypes.int64,
                            dtypes.float32,
                            initializer=2.0,
                            dim=dim,
                        )
                        params_nn = variable_scope.get_variable(
                            "n", shape=[100, dim], use_resource=False)
                        ids = script_ops.py_func(
                            _create_dynamic_shape_tensor(min_val=0,
                                                         max_val=100),
                            inp=[],
                            Tout=dtypes.int64,
                            stateful=True,
                        )
                        if ids_shape is not None:
                            ids = array_ops.reshape(ids, ids_shape)

                        embedding_test = de.embedding_lookup(params, ids)
                        embedding_base = embedding_ops.embedding_lookup(
                            params_nn, ids)

                        # check static shape
                        if ids_shape is None:
                            # ids with unknown shape
                            self.assertTrue(
                                embedding_test.shape == embedding_base.shape)
                        else:
                            # ids with a partially defined shape.
                            self.assertAllEqual(
                                embedding_test.shape.as_list(),
                                embedding_base.shape.as_list(),
                            )

                        self.evaluate(variables.global_variables_initializer())

                        # check dynamic (runtime) shape
                        for _ in range(10):
                            embedding_test_val, embedding_base_val = self.evaluate(
                                [embedding_test, embedding_base])
                            self.assertAllEqual(embedding_test_val.shape,
                                                embedding_base_val.shape)
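
`_create_dynamic_shape_tensor` is used here (and in the save/restore test later) but not defined in the excerpt. A hedged sketch of a factory with the same calling convention (the real helper in the TFRA tests may differ in key range and length limits):

import numpy as np

def _create_dynamic_shape_tensor(min_val=1, max_val=100, dtype=np.int64):
  # Returns a zero-argument callable for script_ops.py_func: each call yields
  # a 1-D array of keys whose length changes from step to step.
  def _dynamic_shape_tensor():
    length = np.random.randint(1, 64)
    return np.random.randint(min_val, max_val, size=(length,), dtype=dtype)
  return _dynamic_shape_tensor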
def create_slots(primary, init, slot_name, op_name, bp_v2):
    """Helper function for creating a slot variable for statefull optimizers."""
    params_var_, params_ids_ = primary.params, primary.ids

    scope_store = variable_scope._get_default_variable_store()
    full_name = params_var_.name + "/" + op_name + "/" + slot_name
    if full_name not in scope_store._vars:
        with ops.colocate_with(primary, ignore_existing=True):
            slot_variable_ = de.Variable(
                name=full_name,
                key_dtype=params_var_.key_dtype,
                value_dtype=params_var_.value_dtype,
                dim=params_var_.dim,
                devices=params_var_.devices,
                partitioner=params_var_.partition_fn,
                initializer=init,
                kv_creator=params_var_.kv_creator,
                trainable=False,
                checkpoint=params_var_.checkpoint,
                bp_v2=bp_v2 if bp_v2 is not None else params_var_.bp_v2,
            )

        scope_store._vars[full_name] = slot_variable_
        # Record the optimizer Variable into trace.
        primary._optimizer_vars.append(slot_variable_)

    slot_trainable = None
    if context.executing_eagerly():
        slot_tw_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
    else:
        # In graph mode, earlier versions used only slot_name as the name of
        # the slots' trainable wrappers, so keep using slot_name here for
        # forward compatibility.
        slot_tw_name = slot_name
    if isinstance(primary, de.shadow_ops.ShadowVariable):
        slot_trainable = de.shadow_ops.ShadowVariable(
            params=scope_store._vars[full_name],
            ids=primary.ids,
            exists=primary.exists,
            name=full_name,
            trainable=False,
        )
    else:
        _, slot_trainable = de.embedding_lookup(
            params=scope_store._vars[full_name],
            ids=params_ids_,
            name=slot_tw_name,
            return_trainable=True,
        )

    return slot_trainable
Example #19
 def test_treated_as_worker_op_by_device_setter(self):
   num_ps_tasks = 2
   with ops.device("/job:worker/task:0"):
     ids = constant_op.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                                dtype=dtypes.int64)
   setter = device_setter.replica_device_setter(ps_tasks=num_ps_tasks,
                                                ps_device="/job:ps",
                                                worker_device="/job:worker")
   with ops.device(setter):
     p1 = de.get_variable(name="p1",
                          devices=["/job:ps/task:0", "/job:ps/task:1"])
     _ = de.embedding_lookup(p1, ids, name="emb")
   self.assertTrue("/job:ps/task:0" in p1._tables[0].resource_handle.device)
   self.assertTrue("/job:ps/task:1" in p1._tables[1].resource_handle.device)
  def test_inference_numberic_correctness(self):
    train_pred = None
    infer_pred = None
    dim = 8
    initializer = init_ops.random_normal_initializer(0.0, 0.001)
    raw_init_vals = np.random.rand(100, dim)

    for fn in [de.enable_train_mode, de.enable_inference_mode]:
      with ops.Graph().as_default():
        fn()

        init_ids = constant_op.constant(list(range(100)), dtype=dtypes.int64)
        init_vals = constant_op.constant(raw_init_vals, dtype=dtypes.float32)
        with variable_scope.variable_scope("modelmode",
                                           reuse=variable_scope.AUTO_REUSE):
          embeddings = de.get_variable('ModelModeTest-numberic',
                                       key_dtype=dtypes.int64,
                                       value_dtype=dtypes.float32,
                                       devices=_get_devices() * 2,
                                       initializer=initializer,
                                       dim=dim)

          w = variables.Variable(1.0, name="w")
          _ = training_util.create_global_step()
        init_op = embeddings.upsert(init_ids, init_vals)

        ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
        test_var, trainable = de.embedding_lookup([embeddings],
                                                  ids,
                                                  return_trainable=True)
        pred = math_ops.add(test_var, 1) * w
        loss = pred * pred
        opt = de.DynamicEmbeddingOptimizer(adagrad.AdagradOptimizer(0.1))
        opt.minimize(loss)

        with monitored_session.MonitoredTrainingSession(
            is_chief=True, config=default_config) as sess:
          if de.get_model_mode() == de.ModelMode.TRAIN:
            sess.run(init_op)
            train_pred = sess.run(pred)
          elif de.get_model_mode() == de.ModelMode.INFERENCE:
            sess.run(init_op)
            infer_pred = sess.run(pred)
      de.enable_train_mode()
      ops.reset_default_graph()
    self.assertAllEqual(train_pred, infer_pred)
    def _test_warm_start_rename(self, num_shards, use_regex):
        devices = ["/cpu:0" for _ in range(num_shards)]
        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
        id_list = [x for x in range(100)]
        val_list = [[x] for x in range(100)]

        emb_name = "t200_{}_{}".format(num_shards, use_regex)
        with self.session(graph=ops.Graph()) as sess:
            embeddings = de.get_variable("save_{}".format(emb_name),
                                         dtypes.int64,
                                         dtypes.float32,
                                         devices=devices,
                                         initializer=0.0)
            ids = constant_op.constant(id_list, dtype=dtypes.int64)
            vals = constant_op.constant(val_list, dtype=dtypes.float32)
            self.evaluate(embeddings.upsert(ids, vals))
            save = saver.Saver(var_list=[embeddings])
            save.save(sess, ckpt_prefix)

        with self.session(graph=ops.Graph()) as sess:
            embeddings = de.get_variable("restore_{}".format(emb_name),
                                         dtypes.int64,
                                         dtypes.float32,
                                         devices=devices,
                                         initializer=0.0)
            ids = constant_op.constant(id_list, dtype=dtypes.int64)
            emb = de.embedding_lookup(embeddings, ids, name="lookup")
            sess.graph.add_to_collection(
                de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, embeddings)
            vars_to_warm_start = [embeddings]
            if use_regex:
                vars_to_warm_start = [".*t200.*"]

            restore_op = de.warm_start(ckpt_to_initialize_from=ckpt_prefix,
                                       vars_to_warm_start=vars_to_warm_start,
                                       var_name_to_prev_var_name={
                                           "restore_{}".format(emb_name):
                                           "save_{}".format(emb_name)
                                       })
            self.evaluate(restore_op)
            self.assertAllEqual(emb, val_list)
        def _model_fn(features, labels, mode, params):
            ids = features['ids']
            embeddings = de.get_variable(emb_name,
                                         dtypes.int64,
                                         dtypes.float32,
                                         devices=devices,
                                         initializer=0.0)
            emb = de.embedding_lookup(embeddings, ids, name="lookup")
            emb.graph.add_to_collection(
                de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES, embeddings)
            vars_to_warm_start = [embeddings]
            if use_regex:
                vars_to_warm_start = [".*t300.*"]

            warm_start_hook = de.WarmStartHook(
                ckpt_to_initialize_from=ckpt_prefix,
                vars_to_warm_start=vars_to_warm_start)
            return tf.estimator.EstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=emb,
                prediction_hooks=[warm_start_hook])
    def test_embedding_lookup_sparse_with_initializer(self):
        id = 0
        embed_dim = 8
        elements_num = 262144
        for initializer, target_mean, target_stddev in [
            (init_ops.random_normal_initializer(0.0, 0.001), 0.0, 0.001),
            (init_ops.truncated_normal_initializer(0.0, 0.001), 0.0, 0.00088),
            (keras_init_ops.RandomNormalV2(mean=0.0,
                                           stddev=0.001), 0.0, 0.001),
        ]:
            with self.session(config=default_config,
                              use_gpu=test_util.is_gpu_available()):
                id += 1
                embedding_weights = de.get_variable(
                    "emb-init-bugfix-" + str(id),
                    key_dtype=dtypes.int64,
                    value_dtype=dtypes.float32,
                    devices=_get_devices() * 3,
                    initializer=initializer,
                    dim=embed_dim,
                )

                ids = np.random.randint(
                    -0x7FFFFFFFFFFFFFFF,
                    0x7FFFFFFFFFFFFFFF,
                    elements_num,
                    dtype=np.int64,
                )
                ids = np.unique(ids)
                ids = constant_op.constant(ids, dtypes.int64)
                vals_op = de.embedding_lookup(embedding_weights, ids,
                                              None).eval()

                mean = self.evaluate(math_ops.reduce_mean(vals_op))
                stddev = self.evaluate(math_ops.reduce_std(vals_op))
                rtol = 2e-5
                atol = rtol
                self.assertTrue(not (list(vals_op[0]) == list(vals_op[1])))
                self.assertAllClose(target_mean, mean, rtol, atol)
                self.assertAllClose(target_stddev, stddev, rtol, atol)
def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for statefull optimizers."""
  # lwk @primary是一个de.TrainableWrapper
  # lwk 所以@params_var_是de.Variable,params_ids_是特征ID
  params_var_, params_ids_ = primary.params, primary.ids

  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      # lwk: the slot variable of a de.Variable is itself a de.Variable, i.e. another hash table.
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim,  # same dim as the primary variable
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )

    scope_store._vars[full_name] = slot_variable_

  # lwk: so the core idea of the dynamic hash table is to wrap the lookup result in a variable;
  # lwk: that variable carries the hash table and the ids, and the optimizer uses the same
  # lwk: information to build a matching slot variable.
  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable
    def common_minimize_trainable(self, base_opt, test_opt, name):
        base_opt = de.DynamicEmbeddingOptimizer(base_opt)
        test_opt = de.DynamicEmbeddingOptimizer(test_opt)
        id = 0
        for (
                num_shards,
                k_dtype,
                d_dtype,
                initial_mode,
                dim,
                run_step,
        ) in itertools.product(
            [1, 2],
            [
                dtypes.int64,
            ],
            [
                dtypes.float32,
            ],
            [
                "constant",
            ],
            [1, 10],
            [10],
        ):
            id += 1
            with self.session(use_gpu=test_util.is_gpu_available(),
                              config=default_config) as sess:
                # common define
                raw_init_ids = [0, 1]
                raw_init_vals = np.random.rand(2, dim)
                raw_ids = [
                    0,
                ]
                x = constant_op.constant(np.random.rand(dim, len(raw_ids)),
                                         dtype=d_dtype)

                # base graph
                base_var = resource_variable_ops.ResourceVariable(
                    raw_init_vals, dtype=d_dtype)
                ids = constant_op.constant(raw_ids, dtype=k_dtype)
                pred0 = math_ops.matmul(
                    embedding_ops.embedding_lookup([base_var], ids), x)
                loss0 = pred0 * pred0
                base_opt_op = base_opt.minimize(loss0)

                # test graph
                embeddings = de.get_variable(
                    "t2020-" + name + str(id),
                    key_dtype=k_dtype,
                    value_dtype=d_dtype,
                    devices=_get_devices() * num_shards,
                    initializer=1.0,
                    dim=dim,
                )
                self.device_check(embeddings)
                init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
                init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
                init_op = embeddings.upsert(init_ids, init_vals)
                self.evaluate(init_op)

                test_var, trainable = de.embedding_lookup(
                    [embeddings], ids, return_trainable=True)
                pred1 = math_ops.matmul(test_var, x)
                loss1 = pred1 * pred1

                test_opt_op = test_opt.minimize(loss1, var_list=[trainable])

                self.evaluate(variables.global_variables_initializer())

                for _ in range(run_step):
                    sess.run(base_opt_op)

                # Fetch params to validate initial values
                self.assertAllCloseAccordingToType(raw_init_vals[raw_ids],
                                                   self.evaluate(test_var))
                # Run `run_step` step of sgd
                for _ in range(run_step):
                    sess.run(test_opt_op)

                table_var = embeddings.lookup(ids)
                # Validate updated params
                self.assertAllCloseAccordingToType(
                    self.evaluate(base_var)[raw_ids],
                    self.evaluate(table_var),
                    msg="Cond:{},{},{},{},{},{}".format(
                        num_shards, k_dtype, d_dtype, initial_mode, dim,
                        run_step),
                )
    def test_traing_save_restore(self):
        opt = de.DynamicEmbeddingOptimizer(adam.AdamOptimizer(0.3))
        id = 0
        if test_util.is_gpu_available():
            dim_list = [1, 2, 4, 8, 10, 16, 32, 64, 100, 256, 500]
        else:
            dim_list = [10]
        for key_dtype, value_dtype, dim, step in itertools.product(
            [dtypes.int64],
            [dtypes.float32],
                dim_list,
            [10],
        ):
            id += 1
            save_dir = os.path.join(self.get_temp_dir(), "save_restore")
            save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

            ids = script_ops.py_func(_create_dynamic_shape_tensor(),
                                     inp=[],
                                     Tout=key_dtype,
                                     stateful=True)

            params = de.get_variable(
                name="params-test-0915-" + str(id),
                key_dtype=key_dtype,
                value_dtype=value_dtype,
                initializer=init_ops.random_normal_initializer(0.0, 0.01),
                dim=dim,
            )
            _, var0 = de.embedding_lookup(params, ids, return_trainable=True)

            def loss():
                return var0 * var0

            params_keys, params_vals = params.export()
            mini = opt.minimize(loss, var_list=[var0])
            opt_slots = [opt.get_slot(var0, _s) for _s in opt.get_slot_names()]
            _saver = saver.Saver([params] + [_s.params for _s in opt_slots])

            with self.session(config=default_config,
                              use_gpu=test_util.is_gpu_available()) as sess:
                self.evaluate(variables.global_variables_initializer())
                for _i in range(step):
                    self.evaluate([mini])
                size_before_saved = self.evaluate(params.size())
                np_params_keys_before_saved = self.evaluate(params_keys)
                np_params_vals_before_saved = self.evaluate(params_vals)
                opt_slots_kv_pairs = [_s.params.export() for _s in opt_slots]
                np_slots_kv_pairs_before_saved = [
                    self.evaluate(_kv) for _kv in opt_slots_kv_pairs
                ]
                _saver.save(sess, save_path)

            with self.session(config=default_config,
                              use_gpu=test_util.is_gpu_available()) as sess:
                self.evaluate(variables.global_variables_initializer())
                self.assertAllEqual(0, self.evaluate(params.size()))

                _saver.restore(sess, save_path)
                params_keys_restored, params_vals_restored = params.export()
                size_after_restored = self.evaluate(params.size())
                np_params_keys_after_restored = self.evaluate(
                    params_keys_restored)
                np_params_vals_after_restored = self.evaluate(
                    params_vals_restored)

                opt_slots_kv_pairs_restored = [
                    _s.params.export() for _s in opt_slots
                ]
                np_slots_kv_pairs_after_restored = [
                    self.evaluate(_kv) for _kv in opt_slots_kv_pairs_restored
                ]
                self.assertAllEqual(size_before_saved, size_after_restored)
                self.assertAllEqual(
                    np.sort(np_params_keys_before_saved),
                    np.sort(np_params_keys_after_restored),
                )
                self.assertAllEqual(
                    np.sort(np_params_vals_before_saved, axis=0),
                    np.sort(np_params_vals_after_restored, axis=0),
                )
                for pairs_before, pairs_after in zip(
                        np_slots_kv_pairs_before_saved,
                        np_slots_kv_pairs_after_restored):
                    self.assertAllEqual(
                        np.sort(pairs_before[0], axis=0),
                        np.sort(pairs_after[0], axis=0),
                    )
                    self.assertAllEqual(
                        np.sort(pairs_before[1], axis=0),
                        np.sort(pairs_after[1], axis=0),
                    )
                if test_util.is_gpu_available():
                    self.assertTrue(
                        "GPU" in params.tables[0].resource_handle.device)
Example #27
def model_fn(features, labels, mode, params):
    #logging.info('mode: %s, labels: %s, params: %s, features: %s', mode, labels, params, features)
    if params["args"].get("addon_embedding"):
        import tensorflow_recommenders_addons as tfra
        import tensorflow_recommenders_addons.dynamic_embedding as dynamic_embedding
    else:
        import tensorflow.dynamic_embedding as dynamic_embedding

    features.update(labels)
    logging.info("------ build hyper parameters -------")
    embedding_size = params["parameters"]["embedding_size"]
    learning_rate = params["parameters"]["learning_rate"]
    use_bn = params["parameters"]["use_bn"]

    feat = params['features']
    sparse_feat_list = list(set(feat["sparse"]) -
                            set(SPARSE_MASK)) if 'sparse' in feat else []
    sparse_seq_feat_list = list(
        set(feat["sparse_seq"]) -
        set(SPARSE_SEQ_MASK)) if 'sparse_seq' in feat else []
    sparse_seq_feat_list = []
    #dense_feat_list = list(set(feat["dense"]) - set(DENSE_MASK)) if 'dense' in feat else []
    # Neither the hashtable v1 nor the v2 image can turn off BN and mask dense_feat at the same time.
    dense_feat_list = []
    dense_seq_feat_list = list(
        set(feat["dense_seq"]) -
        set(DENSE_SEQ_MASK)) if 'dense_seq' in feat else []

    sparse_feat_num = len(sparse_feat_list)
    sparse_seq_num = len(sparse_seq_feat_list)
    dense_feat_num = len(dense_feat_list)
    dense_seq_feat_num = len(dense_seq_feat_list)

    all_features = (sparse_feat_list + sparse_seq_feat_list + dense_feat_list +
                    dense_seq_feat_list)

    batch_size = tf.shape(features[goods_id_feat])[0]
    logging.info("------ show batch_size: {} -------".format(batch_size))

    level_0_feats = list(
        set(params.get('level_0_feat_list')) & set(all_features))
    logging.info('level_0_feats: {}'.format(level_0_feats))
    new_features = dict()
    if params["args"].get("level_flag") and params["args"].get(
            "job_type") == "export":
        for feature_name in features:
            if feature_name in level_0_feats:
                new_features[feature_name] = tf.reshape(
                    tf.tile(tf.reshape(features[feature_name], [1, -1]),
                            [batch_size, 1]), [batch_size, -1])
            else:
                new_features[feature_name] = features[feature_name]
        features = new_features

    l2_reg = params["parameters"]["l2_reg"]
    is_training = True if mode == tf.estimator.ModeKeys.TRAIN else False
    has_label = True if 'is_imp' in features else False
    logging.info("is_training: {}, has_label: {}, features: {}".format(
        is_training, has_label, features))

    logging.info("------ build embedding -------")
    # def partition_fn(keys, shard_num=params["parameters"]["ps_nums"]):
    #     return tf.cast(keys % shard_num, dtype=tf.int32)
    if is_training:
        devices_info = [
            "/job:ps/replica:0/task:{}/CPU:0".format(i)
            for i in range(params["parameters"]["ps_num"])
        ]
        initializer = tf.compat.v1.truncated_normal_initializer(0.0, 1e-2)
    else:
        devices_info = [
            "/job:localhost/replica:0/task:{}/CPU:0".format(0)
            for i in range(params["parameters"]["ps_num"])
        ]
        initializer = tf.compat.v1.zeros_initializer()
    logging.info("------ dynamic_embedding devices_info is {}-------".format(
        devices_info))
    if mode == tf.estimator.ModeKeys.PREDICT:
        dynamic_embedding.enable_inference_mode()

    deep_dynamic_variables = dynamic_embedding.get_variable(
        name="deep_dynamic_embeddings",
        devices=devices_info,
        initializer=initializer,
        # partitioner=partition_fn,
        dim=embedding_size,
        trainable=is_training,
        #init_size=INIT_SIZE
    )

    sparse_feat = None
    sparse_unique_ids = None
    if sparse_feat_num > 0:
        logging.info("------ build sparse feature -------")
        id_list = sorted(sparse_feat_list)
        ft_sparse_idx = tf.concat(
            [tf.reshape(features[str(i)], [-1, 1]) for i in id_list], axis=1)
        sparse_unique_ids, sparse_unique_idx = tf.unique(
            tf.reshape(ft_sparse_idx, [-1]))

        sparse_weights = dynamic_embedding.embedding_lookup(
            params=deep_dynamic_variables,
            ids=sparse_unique_ids,
            name="deep_sparse_weights")
        if params["args"].get("zero_padding"):
            sparse_weights = tf.reshape(sparse_weights, [-1, embedding_size])
            sparse_weights = tf.where(
                tf.not_equal(
                    tf.expand_dims(sparse_unique_ids, axis=1),
                    tf.zeros_like(tf.expand_dims(sparse_unique_ids, axis=1))),
                sparse_weights, tf.zeros_like(sparse_weights))

        sparse_weights = tf.gather(sparse_weights, sparse_unique_idx)
        sparse_feat = tf.reshape(
            sparse_weights,
            shape=[batch_size, sparse_feat_num * embedding_size])

    sparse_seq_feat = None
    sparse_seq_unique_ids = None
    if sparse_seq_num > 0:
        logging.info("---- build sparse seq feature ---")
        if params["args"].get("merge_sparse_seq"):
            sparse_seq_name_list = sorted(
                sparse_seq_feat_list)  #[B, s1], [B, s2], ... [B, sn]
            ft_sparse_seq_ids = tf.concat(
                [
                    tf.reshape(features[str(i)], [batch_size, -1])
                    for i in sparse_seq_name_list
                ],
                axis=1)  #[B, [s1, s2, ...sn]] => [B, per_seq_len*seq_num]

            sparse_seq_unique_ids, sparse_seq_unique_idx = tf.unique(
                tf.reshape(ft_sparse_seq_ids,
                           [-1]))  #[u], [B*per_seq_len*seq_num]

            sparse_seq_weights = dynamic_embedding.embedding_lookup(
                params=deep_dynamic_variables,
                ids=sparse_seq_unique_ids,
                name="deep_sparse_seq_weights")  #[u, e]

            deep_embed_seq = tf.where(
                tf.not_equal(
                    tf.expand_dims(sparse_seq_unique_ids, axis=1),
                    tf.zeros_like(tf.expand_dims(sparse_seq_unique_ids,
                                                 axis=1))), sparse_seq_weights,
                tf.zeros_like(sparse_seq_weights))  #[u, e]

            deep_embedding_seq = tf.reshape(
                tf.gather(deep_embed_seq,
                          sparse_seq_unique_idx),  #[B*per_seq_len*seq_num, e]
                shape=[batch_size, sparse_seq_num, -1,
                       embedding_size])  #[B, seq_num, per_seq_len, e]
            if params["parameters"]["combiner"] == "sum":
                tmp_feat = tf.reduce_sum(deep_embedding_seq, axis=2)
            else:
                tmp_feat = tf.reduce_mean(deep_embedding_seq, axis=2)
            sparse_seq_feat = tf.reshape(
                tmp_feat,
                [batch_size, sparse_seq_num * embedding_size])  #[B, seq_num*e]
        else:
            sparse_seq_feats = []
            sparse_ids = []
            for sparse_seq_name in sparse_seq_feat_list:
                sp_ids = features[sparse_seq_name]
                if params["args"].get("zero_padding2"):
                    sparse_seq_unique_ids, sparse_seq_unique_idx, _ = tf.unique_with_counts(
                        tf.reshape(sp_ids, [-1]))

                    deep_sparse_seq_weights = tf.reshape(
                        dynamic_embedding.embedding_lookup(
                            params=deep_dynamic_variables,
                            ids=sparse_seq_unique_ids,
                            name="deep_sparse_weights_{}".format(
                                sparse_seq_name)), [-1, embedding_size])

                    deep_embed_seq = tf.where(
                        tf.not_equal(
                            tf.expand_dims(sparse_seq_unique_ids, axis=1),
                            tf.zeros_like(
                                tf.expand_dims(sparse_seq_unique_ids,
                                               axis=1))),
                        deep_sparse_seq_weights,
                        tf.zeros_like(deep_sparse_seq_weights))

                    deep_embedding_seq = tf.reshape(
                        tf.gather(deep_embed_seq, sparse_seq_unique_idx),
                        shape=[batch_size, -1, embedding_size])

                    if params["parameters"]["combiner"] == "sum":
                        tmp_feat = tf.reduce_sum(deep_embedding_seq, axis=1)
                    else:
                        tmp_feat = tf.reduce_mean(deep_embedding_seq, axis=1)
                    sparse_ids.append(sparse_seq_unique_ids)
                    sparse_seq_feats.append(
                        tf.reshape(tmp_feat, [batch_size, embedding_size]))
                else:
                    tmp_feat = dynamic_embedding.safe_embedding_lookup_sparse(
                        embedding_weights=deep_dynamic_variables,
                        sparse_ids=sp_ids,
                        combiner=params["parameters"]["combiner"],
                        name="safe_embedding_lookup_sparse")
                    temp_uni_id, _, _ = tf.unique_with_counts(
                        tf.reshape(sp_ids.values, [-1]))
                    sparse_ids.append(temp_uni_id)
                    sparse_seq_feats.append(
                        tf.reshape(tmp_feat, [batch_size, embedding_size]))

            sparse_seq_feat = tf.concat(sparse_seq_feats, axis=1)
            sparse_seq_unique_ids, _ = tf.unique(tf.concat(sparse_ids, axis=0))

    dense_feat = None
    if dense_feat_num > 0:
        logging.info("------ build dense feature -------")
        den_id_list = sorted(dense_feat_list)
        dense_feat_base = tf.concat(
            [tf.reshape(features[str(i)], [-1, 1]) for i in den_id_list],
            axis=1)

        #deep_dense_w1 = tf.compat.v1.get_variable('deep_dense_w1',
        #                                          tf.TensorShape([dense_feat_num]),
        #                                          initializer=tf.compat.v1.truncated_normal_initializer(
        #                                              2.0 / math.sqrt(dense_feat_num)),
        #                                          dtype=tf.float32)
        #deep_dense_w2 = tf.compat.v1.get_variable('deep_dense_w2',
        #                                          tf.TensorShape([dense_feat_num]),
        #                                          initializer=tf.compat.v1.truncated_normal_initializer(
        #                                              2.0 / math.sqrt(dense_feat_num)),
        #                                          dtype=tf.float32)

        #w1 = tf.tile(tf.expand_dims(deep_dense_w1, axis=0), [tf.shape(dense_feat_base)[0], 1])
        #dense_input_1 = tf.multiply(dense_feat_base, w1)
        #dense_feat = dense_input_1
        dense_feat = dense_feat_base

    dense_seq_feat = None
    if dense_seq_feat_num > 0:
        logging.info("------ build dense seq feature -------")
        den_seq_id_list = sorted(dense_seq_feat_list)
        dense_seq_feat = tf.concat([
            tf.reshape(features[str(i[0])], [-1, i[1]])
            for i in den_seq_id_list
        ],
                                   axis=1)

    logging.info("------ join all feature -------")
    fc_inputs = tf.concat([
        x for x in [sparse_feat, sparse_seq_feat, dense_feat, dense_seq_feat]
        if x is not None
    ],
                          axis=1)

    logging.info("---- tracy debug input is ----")
    logging.info(sparse_feat)
    logging.info(sparse_seq_feat)
    logging.info(dense_feat)
    logging.info(dense_seq_feat)
    logging.info(fc_inputs)

    logging.info("------ join fc -------")
    for idx, units in enumerate(params["parameters"]["hidden_units"]):
        fc_inputs = fully_connected_with_bn_ahead(
            inputs=fc_inputs,
            num_outputs=units,
            l2_reg=l2_reg,
            scope="out_mlp_{}".format(idx),
            activation_fn=tf.nn.relu,
            train_phase=is_training,
            use_bn=use_bn)
    y_deep_ctr = fully_connected_with_bn_ahead(inputs=fc_inputs,
                                               num_outputs=1,
                                               activation_fn=tf.identity,
                                               l2_reg=l2_reg,
                                               scope="ctr_mlp",
                                               train_phase=is_training,
                                               use_bn=use_bn)

    logging.info("------ build ctr out -------")
    sample_rate = params["args"]["sample_rate"]
    logit = tf.reshape(y_deep_ctr, shape=[-1], name="logit")
    sample_logit = get_sample_logits(logit, sample_rate)
    pred_ctr = tf.nn.sigmoid(logit, name="pred_ctr")
    sample_pred_ctr = tf.nn.sigmoid(sample_logit, name="sample_pred_ctr")

    logging.info("------ build predictions -------")
    preds = {
        'p_ctr': tf.reshape(pred_ctr, shape=[-1, 1]),
    }

    logging.info("---- deep_dynamic_variables.size ----")
    logging.info(deep_dynamic_variables.size())
    size = tf.identity(deep_dynamic_variables.size(), name="size")

    label_col = "is_clk"
    if params["args"].get("set_train_labels"):
        label_col = params["args"]["set_train_labels"]["1"]

    logging.info(
        "------ build labels, label_col: {} -------".format(label_col))
    if has_label:
        labels_ctr = tf.reshape(features["is_clk"],
                                shape=[-1],
                                name="labels_ctr")

    if mode == tf.estimator.ModeKeys.PREDICT:
        logging.info("---- build tf-serving predict ----")
        pred_cvr = tf.fill(tf.shape(pred_ctr), 1.0)
        preds.update({
            'labels_cart': tf.reshape(pred_cvr, shape=[-1, 1]),
            'p_car': tf.reshape(features["dense_1608"], shape=[-1, 1]),
            'labels_cvr': tf.reshape(pred_cvr, shape=[-1, 1]),
            'p_cvr': tf.reshape(pred_cvr, shape=[-1, 1]),
        })
        if 'logid' in features:
            preds.update(
                {'logid': tf.reshape(features["logid"], shape=[-1, 1])})
        if has_label:
            logging.info("------ build offline label -------")
            preds["labels_ctr"] = tf.reshape(labels_ctr, shape=[-1, 1])
        export_outputs = {
            "predict_export_outputs":
            tf.estimator.export.PredictOutput(outputs=preds)
        }
        return tf.estimator.EstimatorSpec(mode,
                                          predictions=preds,
                                          export_outputs=export_outputs)

    logging.info("----all vars:-----" + str(tf.compat.v1.global_variables()))
    for var in tf.compat.v1.trainable_variables():
        logging.info("----trainable------" + str(var))

    logging.info("------ build metric -------")
    #loss = tf.reduce_mean(
    #    tf.compat.v1.losses.log_loss(labels=labels_ctr, predictions=sample_pred_ctr),
    #    name="loss")
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels_ctr, logits=sample_logit),
                          name="loss")
    ctr_auc = tf.compat.v1.metrics.auc(labels=labels_ctr,
                                       predictions=sample_pred_ctr,
                                       name="ctr_auc")

    label_ctr_avg = tf.reduce_mean(labels_ctr, name="label_ctr_avg")
    real_pred_ctr_avg = tf.reduce_mean(pred_ctr, name="real_pred_ctr_avg")
    sample_pred_ctr_avg = tf.reduce_mean(sample_pred_ctr, name="pred_ctr_avg")
    sample_pred_bias_avg = tf.add(sample_pred_ctr_avg,
                                  tf.negative(label_ctr_avg),
                                  name="pred_bias_avg")
    tf.compat.v1.summary.histogram('labels_ctr', labels_ctr)
    tf.compat.v1.summary.histogram('pred_ctr', sample_pred_ctr)
    tf.compat.v1.summary.histogram('real_pred_ctr', pred_ctr)

    tf.compat.v1.summary.scalar('label_ctr_avg', label_ctr_avg)
    tf.compat.v1.summary.scalar('pred_ctr_avg', sample_pred_ctr_avg)
    tf.compat.v1.summary.scalar('real_pred_ctr_avg', real_pred_ctr_avg)
    tf.compat.v1.summary.scalar('pred_bias_avg', sample_pred_bias_avg)
    tf.compat.v1.summary.scalar('loss', loss)
    tf.compat.v1.summary.scalar('ctr_auc', ctr_auc[1])

    logging.info("------ compute l2 reg -------")
    if params["parameters"]["use_l2"]:
        all_unique_ids, _ = tf.unique(
            tf.concat([
                x for x in [sparse_unique_ids, sparse_seq_unique_ids]
                if x is not None
            ],
                      axis=0))

        all_unique_ids_w = dynamic_embedding.embedding_lookup(
            deep_dynamic_variables,
            all_unique_ids,
            name="unique_ids_weights",
            return_trainable=False)
        embed_loss = l2_reg * tf.nn.l2_loss(
            tf.reshape(all_unique_ids_w,
                       shape=[-1, embedding_size])) + tf.reduce_sum(
                           tf.compat.v1.get_collection(
                               tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES))

        tf.compat.v1.summary.scalar('embed_loss', embed_loss)
        loss = loss + embed_loss

    loss = tf.identity(loss, name="total_loss")
    tf.compat.v1.summary.scalar('total_loss', loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        logging.info("------ EVAL -------")
        eval_metric_ops = {
            "ctr_auc_eval": ctr_auc,
        }
        if has_label:
            logging.info("------ build offline label -------")
            preds["labels_ctr"] = tf.reshape(labels_ctr, shape=[-1, 1])
        export_outputs = {
            "predict_export_outputs":
            tf.estimator.export.PredictOutput(outputs=preds)
        }
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metric_ops,
                                          export_outputs=export_outputs)

    logging.info("---- Learning rate ----")
    lr = get_learning_rate(params["parameters"]["learning_rate"],
                           params["parameters"]["use_decay"])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.compat.v1.train.get_global_step()
        logging.info("------ TRAIN -------")
        optimizer_type = params["parameters"].get('optimizer', 'Adam')
        if optimizer_type == 'Sgd':
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=lr)
        elif optimizer_type == 'Adagrad':
            optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=lr)
        elif optimizer_type == 'Rmsprop':
            optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=lr)
        elif optimizer_type == 'Ftrl':
            optimizer = tf.compat.v1.train.FtrlOptimizer(learning_rate=lr)
        elif optimizer_type == 'Momentum':
            optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=lr,
                                                             momentum=0.9)
        else:
            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr,
                                                         beta1=0.9,
                                                         beta2=0.999,
                                                         epsilon=1e-8)

        if params["args"].get("addon_embedding"):
            optimizer = dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)

        train_op = optimizer.minimize(loss, global_step=global_step)
        # fix tf2 batch_normalization bug
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        logging.info('train ops: {}, update ops: {}'.format(
            str(train_op), str(update_ops)))
        train_op = tf.group([train_op, update_ops])
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=preds,
                                          loss=loss,
                                          train_op=train_op)
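
A minimal, illustrative sketch of wiring this `model_fn` into an Estimator. The `params` layout below only covers the keys the function actually reads, and every value is an assumption to be replaced with real feature lists and hyper-parameters:

# Hedged wiring sketch; values are placeholders, not a working configuration.
import tensorflow as tf

params = {
    "args": {"addon_embedding": True, "sample_rate": 1.0},
    "parameters": {
        "embedding_size": 8,
        "learning_rate": 0.001,
        "use_bn": False,
        "l2_reg": 1e-5,
        "ps_num": 1,
        "combiner": "mean",
        "hidden_units": [256, 128],
        "use_l2": False,
        "use_decay": False,
        "optimizer": "Adam",
    },
    "features": {"sparse": [], "dense_seq": []},
    "level_0_feat_list": [],
}

estimator = tf.estimator.Estimator(model_fn=model_fn, params=params)
# estimator.train(input_fn=my_input_fn)  # my_input_fn must yield the feature dict the model_fn expects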