  def test_check_ops_number(self):
    self.assertEqual(de.get_model_mode(), "train")
    de.enable_inference_mode()
    self.assertEqual(de.get_model_mode(), "inference")
    de.enable_train_mode()
    self.assertEqual(de.get_model_mode(), "train")
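    # Expected op counts per graph: train mode emits one AssignBeforeReadVariable
    # and two ReadVariableOps per lookup; inference mode emits none and one.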
    for fn, assign_num, read_num in [(de.enable_train_mode, 1, 2),
                                      (de.enable_inference_mode, 0, 1)]:
     fn()
      embeddings = de.get_variable('ModelModeTest' + str(assign_num),
                                  key_dtype=dtypes.int64,
                                  value_dtype=dtypes.float32,
                                  devices=_get_devices(),
                                  initializer=1.,
                                  dim=8)
     ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
     test_var, trainable = de.embedding_lookup([embeddings],
                                               ids,
                                               return_trainable=True)
     _ = math_ops.add(test_var, 1)
     op_list = ops.get_default_graph().get_operations()
     op_list_assign = [
         op.name for op in op_list if "AssignBeforeReadVariable" in op.name
     ]
     op_list_read = [op.name for op in op_list if "ReadVariableOp" in op.name]
      self.assertEqual(len(op_list_assign), assign_num)
      self.assertEqual(len(op_list_read), read_num)
     de.enable_train_mode()
     ops.reset_default_graph()
  def test_inference_numeric_correctness(self):
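    """The same model must produce identical predictions in train and inference mode."""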
    train_pred = None
    infer_pred = None
    dim = 8
    initializer = init_ops.random_normal_initializer(0.0, 0.001)
    raw_init_vals = np.random.rand(100, dim)

    for fn in [de.enable_train_mode, de.enable_inference_mode]:
      with ops.Graph().as_default():
        fn()

        init_ids = constant_op.constant(list(range(100)), dtype=dtypes.int64)
        init_vals = constant_op.constant(raw_init_vals, dtype=dtypes.float32)
        with variable_scope.variable_scope("modelmode",
                                           reuse=variable_scope.AUTO_REUSE):
          embeddings = de.get_variable('ModelModeTest-numeric',
                                       key_dtype=dtypes.int64,
                                       value_dtype=dtypes.float32,
                                       devices=_get_devices() * 2,
                                       initializer=initializer,
                                       dim=dim)

          w = variables.Variable(1.0, name="w")
          _ = training_util.create_global_step()
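        # Seed the table with fixed values so both modes look up identical embeddings.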
        init_op = embeddings.upsert(init_ids, init_vals)

        ids = constant_op.constant([0, 1, 2, 3, 4], dtype=dtypes.int64)
        test_var, trainable = de.embedding_lookup([embeddings],
                                                  ids,
                                                  return_trainable=True)
        pred = math_ops.add(test_var, 1) * w
        loss = pred * pred
        opt = de.DynamicEmbeddingOptimizer(adagrad.AdagradOptimizer(0.1))
        opt.minimize(loss)

        with monitored_session.MonitoredTrainingSession(
            is_chief=True, config=default_config) as sess:
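          # Record the prediction produced under whichever mode is active.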
          if de.get_model_mode() == de.ModelMode.TRAIN:
            sess.run(init_op)
            train_pred = sess.run(pred)
          elif de.get_model_mode() == de.ModelMode.INFERENCE:
            sess.run(init_op)
            infer_pred = sess.run(pred)
      de.enable_train_mode()
      ops.reset_default_graph()
    self.assertAllEqual(train_pred, infer_pred)
  def common_minimize_trainable(self, base_opt, test_opt, name):
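    """Dense-Variable and de.Variable branches trained on the same loss must stay equal."""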
    de.enable_train_mode()
    base_opt = de.DynamicEmbeddingOptimizer(base_opt)
    test_opt = de.DynamicEmbeddingOptimizer(test_opt)
    id = 0
    for (
        num_shards,
        k_dtype,
        d_dtype,
        initial_mode,
        dim,
        run_step,
    ) in itertools.product(
        [3],
        [dtypes.int64],
        [
            dtypes.float32,
        ],
        [
            "constant",
        ],
        [1, 10],
        [10],
    ):
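      # Build a fresh graph for every parameter combination.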
      with ops.Graph().as_default():
        id += 1
        raw_init_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        raw_init_vals = [
            [
                x,
            ] * dim
            for x in [0.0, 0.1, 0.3, 0.8, 0.16, 0.25, 0.36, 0.49, 0.64, 0.81]
        ]
        raw_ids = constant_op.constant([1, 3, 3, 9], dtype=k_dtype)
        sp_ids = sparse_tensor.SparseTensor(
            indices=[
                [0, 0],
                [0, 1],
                [1, 0],
                [2, 1],
            ],
            values=raw_ids,
            dense_shape=[3, 2],
        )
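        # Column weights: 0.4/0.5/0.6 each repeated `dim` times, then flattened
        # into a (3 * dim, 1) column vector for the matmul below.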
        x = constant_op.constant([[_x * dim] for _x in [[0.4], [0.5], [0.6]]],
                                 dtype=d_dtype)
        x = array_ops.reshape(x, shape=(3 * dim, 1))
        # base var prepare
        base_var = variables.Variable(
            np.array(raw_init_vals).reshape([len(raw_init_ids), dim]),
            dtype=d_dtype,
            shape=[len(raw_init_ids), dim],
        )

        # test var prepare
        embeddings = de.get_variable(
            "t1030-" + name + str(id),
            key_dtype=k_dtype,
            value_dtype=d_dtype,
            devices=_get_devices() * num_shards,
            initializer=1.0,
            dim=dim,
        )

        init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
        init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
        init_op = embeddings.upsert(init_ids, init_vals)

        # base branch
        base_embedding = embedding_ops.embedding_lookup_sparse(base_var,
                                                               sp_ids,
                                                               None,
                                                               combiner="sum")
        base_embedding = array_ops.reshape(base_embedding, shape=[1, 3 * dim])
        pred0 = math_ops.matmul(base_embedding, x)
        loss0 = pred0 * pred0

        base_opt_op = base_opt.minimize(loss0, var_list=[base_var])

        # test branch
        test_var, trainable = de.embedding_lookup_sparse(
            embeddings,
            sp_ids,
            sp_weights=None,
            combiner="sum",
            return_trainable=True,
        )

        pred1 = math_ops.matmul(array_ops.reshape(test_var, shape=[1, 3 * dim]),
                                x)
        loss1 = pred1 * pred1

        gstep = training_util.create_global_step()
        test_opt_op = test_opt.minimize(loss1,
                                        var_list=[trainable],
                                        global_step=gstep)

        table_var = array_ops.reshape(embeddings.lookup(init_ids),
                                      shape=[10, dim])

        with monitored_session.MonitoredTrainingSession(
            is_chief=True, config=default_config) as sess:
          sess.run(init_op)
          self.assertAllCloseAccordingToType(
              np.array(raw_init_vals).reshape([len(raw_init_ids), dim]),
              sess.run(base_var),
          )

          # Run the base and test optimizers for the same number of steps.
          for _ in range(run_step):
            sess.run(base_opt_op)
            sess.run(test_opt_op)

          # Validate global_step
          self.assertEqual(run_step, sess.run(gstep))

          # Validate updated params
          self.assertAllCloseAccordingToType(
              sess.run(base_var),
              sess.run(table_var),
              msg="Cond:{},{},{},{},{}".format(num_shards, k_dtype, d_dtype,
                                               dim, run_step),
          )
          self.device_check(embeddings)
  def common_minimize_trainable_v2(self, base_opt, test_opt, name):
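    """Callable-loss (V2 optimizer API) variant of the dense-vs-dynamic comparison."""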
    de.enable_train_mode()
    base_opt = de.DynamicEmbeddingOptimizer(base_opt)
    test_opt = de.DynamicEmbeddingOptimizer(test_opt)
    id = 0
    for (
        num_shards,
        k_dtype,
        d_dtype,
        initial_mode,
        dim,
        run_step,
    ) in itertools.product(
        [1, 2],
        [
            dtypes.int64,
        ],
        [
            dtypes.float32,
        ],
        [
            "constant",
        ],
        [1, 10],
        [10],
    ):
      id += 1
      raw_init_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
      raw_init_vals = [[
          x,
      ] * dim for x in [0.0, 0.1, 0.3, 0.8, 0.16, 0.25, 0.36, 0.49, 0.64, 0.81]]
      raw_ids = constant_op.constant([1, 3, 3, 9], dtype=k_dtype)
      sp_ids = sparse_tensor.SparseTensor(
          indices=[
              [0, 0],
              [0, 1],
              [1, 0],
              [2, 1],
          ],
          values=raw_ids,
          dense_shape=[3, 2],
      )
      x = constant_op.constant([[_x * dim] for _x in [[0.4], [0.5], [0.6]]],
                               dtype=d_dtype)
      x = array_ops.reshape(x, shape=(dim, -1))

      # base graph
      def base_fn():
        embeddings = variables.Variable(
            np.array(raw_init_vals).reshape([len(raw_init_ids), dim]),
            dtype=d_dtype,
            shape=[len(raw_init_ids), dim],
        )

        def loss_fn(emb):
          embedding = embedding_ops.safe_embedding_lookup_sparse(emb,
                                                                 sp_ids,
                                                                 None,
                                                                 combiner="sum")
          pred0 = math_ops.matmul(embedding, x)
          return pred0 * pred0

        base_opt_op = base_opt.minimize(lambda: loss_fn(embeddings),
                                        [embeddings])
        self.evaluate(variables.global_variables_initializer())
        for _ in range(run_step):
          self.evaluate(base_opt_op)
        return embeddings

      base_opt_val = self.evaluate(base_fn())

      def test_fn():
        embeddings = de.get_variable(
            "s6030-v2-" + name + str(id),
            key_dtype=k_dtype,
            value_dtype=d_dtype,
            devices=_get_devices() * num_shards,
            initializer=1.0,
            dim=dim,
        )
        self.device_check(embeddings)

        init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
        init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
        self.evaluate(embeddings.upsert(init_ids, init_vals))
        trainables = []
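        # loss_fn recreates the trainable wrapper on every call; var_fn hands the
        # optimizer whichever wrapper was created last.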

        def var_fn():
          return trainables

        def loss_fn(emb, trainables):
          test_var, trainable = de.safe_embedding_lookup_sparse(
              emb,
              sp_ids,
              sparse_weights=None,
              combiner="sum",
              return_trainable=True,
          )

          pred = math_ops.matmul(test_var, x)
          trainables.clear()
          trainables.append(trainable)
          return pred * pred

        test_opt_op = test_opt.minimize(lambda: loss_fn(embeddings, trainables),
                                        var_fn)
        self.evaluate(variables.global_variables_initializer())
        for _ in range(run_step):
          self.evaluate(test_opt_op)
        return embeddings.lookup(init_ids)

      test_opt_val = self.evaluate(test_fn())
      self.assertAllCloseAccordingToType(
          base_opt_val,
          test_opt_val,
          msg="Cond:{},{},{},{},{},{}".format(num_shards, k_dtype, d_dtype,
                                              initial_mode, dim, run_step),
      )
  def common_minimize_trainable(self, base_opt, test_opt, name):
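    """Graph-session variant of the comparison, built on safe_embedding_lookup_sparse."""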
    de.enable_train_mode()
    base_opt = de.DynamicEmbeddingOptimizer(base_opt)
    test_opt = de.DynamicEmbeddingOptimizer(test_opt)
    id = 0
    config = config_pb2.ConfigProto(
        allow_soft_placement=True,
        gpu_options=config_pb2.GPUOptions(allow_growth=True),
    )
    for (
        num_shards,
        k_dtype,
        d_dtype,
        initial_mode,
        dim,
        run_step,
    ) in itertools.product(
        [1, 2],
        [dtypes.int64],
        [
            dtypes.float32,
        ],
        [
            "constant",
        ],
        [1, 10],
        [10],
    ):
      with self.session(config=config, use_gpu=test_util.is_gpu_available()):
        id += 1
        raw_init_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        raw_init_vals = [
            [
                x,
            ] * dim
            for x in [0.0, 0.1, 0.3, 0.8, 0.16, 0.25, 0.36, 0.49, 0.64, 0.81]
        ]
        raw_ids = constant_op.constant([1, 3, 3, 9], dtype=k_dtype)
        sp_ids = sparse_tensor.SparseTensor(
            indices=[
                [0, 0],
                [0, 1],
                [1, 0],
                [2, 1],
            ],
            values=raw_ids,
            dense_shape=[3, 2],
        )
        x = constant_op.constant([[_x * dim] for _x in [[0.4], [0.5], [0.6]]],
                                 dtype=d_dtype)
        x = array_ops.reshape(x, shape=(3 * dim, 1))
        # base var prepare
        base_var = variables.Variable(
            np.array(raw_init_vals).reshape([len(raw_init_ids), dim]),
            dtype=d_dtype,
            shape=[len(raw_init_ids), dim],
        )
        base_embedding = embedding_ops.safe_embedding_lookup_sparse(
            base_var, sp_ids, None, combiner="sum")
        base_embedding = array_ops.reshape(base_embedding, shape=[1, 3 * dim])
        pred0 = math_ops.matmul(base_embedding, x)
        loss0 = pred0 * pred0

        base_opt_op = base_opt.minimize(loss0, var_list=[base_var])

        # test var prepare
        embeddings = de.get_variable(
            "s6030-" + name + str(id),
            key_dtype=k_dtype,
            value_dtype=d_dtype,
            devices=_get_devices() * num_shards,
            initializer=1.0,
            dim=dim,
        )
        self.device_check(embeddings)

        init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
        init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
        init_op = embeddings.upsert(init_ids, init_vals)
        self.evaluate(init_op)

        # test branch
        test_var, trainable = de.safe_embedding_lookup_sparse(
            embeddings,
            sp_ids,
            sparse_weights=None,
            combiner="sum",
            return_trainable=True,
        )

        pred1 = math_ops.matmul(array_ops.reshape(test_var, shape=[1, 3 * dim]),
                                x)
        loss1 = pred1 * pred1
        test_opt_op = test_opt.minimize(loss1, var_list=[trainable])

        self.evaluate(variables.global_variables_initializer())

        self.assertAllCloseAccordingToType(
            np.array(raw_init_vals).reshape([len(raw_init_ids), dim]),
            self.evaluate(base_var),
        )

        # run base
        for _ in range(run_step):
          self.evaluate(base_opt_op)

        # Run `run_step` steps of the test optimizer.
        for _ in range(run_step):
          self.evaluate(test_opt_op)

        table_var = array_ops.reshape(embeddings.lookup(init_ids),
                                      shape=[10, dim])
        # Validate updated params
        self.assertAllCloseAccordingToType(
            self.evaluate(base_var),
            self.evaluate(table_var),
            msg="Cond:{},{},{},{},{}".format(num_shards, k_dtype, d_dtype, dim,
                                             run_step),
        )
  def common_minimize_trainable_v2(self, base_opt, test_opt, name):
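    """Callable-loss variant comparing de.embedding_lookup against a ResourceVariable baseline."""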
    de.enable_train_mode()
    base_opt = de.DynamicEmbeddingOptimizer(base_opt)
    test_opt = de.DynamicEmbeddingOptimizer(test_opt)
    id = 0
    for (
        num_shards,
        k_dtype,
        d_dtype,
        initial_mode,
        dim,
        run_step,
    ) in itertools.product(
        [1, 2],
        [
            dtypes.int64,
        ],
        [
            dtypes.float32,
        ],
        [
            "constant",
        ],
        [1, 10],
        [10],
    ):
      id += 1
      # Definitions shared by the base and test graphs.
      raw_init_ids = [0, 1]
      raw_init_vals = np.random.rand(2, dim)
      raw_ids = [
          0,
      ]

      # base graph
      def base_fn():
        embeddings = resource_variable_ops.ResourceVariable(raw_init_vals,
                                                            dtype=d_dtype)

        def loss_fn(emb):
          ids = constant_op.constant(raw_ids, dtype=k_dtype)
          pred = embedding_ops.embedding_lookup([emb], ids)
          return pred * pred

        base_opt_op = base_opt.minimize(lambda: loss_fn(embeddings),
                                        [embeddings])
        self.evaluate(variables.global_variables_initializer())
        for _ in range(run_step):
          self.evaluate(base_opt_op)
        return embeddings

      base_opt_val = self.evaluate(base_fn())

      def test_fn():
        embeddings = de.get_variable(
            "t2020-v2-" + name + str(id),
            key_dtype=k_dtype,
            value_dtype=d_dtype,
            devices=_get_devices() * num_shards,
            initializer=1.0,
            dim=dim,
        )
        self.device_check(embeddings)
        trainables = []
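        # Same pattern as above: loss_fn refreshes the trainable wrapper that
        # var_fn exposes to the optimizer.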
        init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
        init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
        self.evaluate(embeddings.upsert(init_ids, init_vals))

        def var_fn():
          return trainables

        def loss_fn(x, trainables):
          ids = constant_op.constant(raw_ids, dtype=k_dtype)
          pred, trainable = de.embedding_lookup([x], ids, return_trainable=True)
          trainables.clear()
          trainables.append(trainable)
          return pred * pred

        test_opt_op = test_opt.minimize(lambda: loss_fn(embeddings, trainables),
                                        var_fn)
        self.evaluate(variables.global_variables_initializer())
        for _ in range(run_step):
          self.evaluate(test_opt_op)
        return embeddings.lookup(init_ids)

      with ops.device(_get_devices()[0]):
        test_opt_val = self.evaluate(test_fn())
      self.assertAllCloseAccordingToType(
          base_opt_val,
          test_opt_val,
          msg="Cond:{},{},{},{},{},{}".format(num_shards, k_dtype, d_dtype,
                                              initial_mode, dim, run_step),
      )
  def common_minimize_trainable(self, base_opt, test_opt, name):
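    """Graph-session variant built on plain embedding_lookup with a ResourceVariable baseline."""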
    de.enable_train_mode()
    base_opt = de.DynamicEmbeddingOptimizer(base_opt)
    test_opt = de.DynamicEmbeddingOptimizer(test_opt)
    id = 0
    for (
        num_shards,
        k_dtype,
        d_dtype,
        initial_mode,
        dim,
        run_step,
    ) in itertools.product(
        [1, 2],
        [
            dtypes.int64,
        ],
        [
            dtypes.float32,
        ],
        [
            "constant",
        ],
        [1, 10],
        [10],
    ):
      id += 1
      with self.session(use_gpu=test_util.is_gpu_available(),
                        config=default_config) as sess:
        # Definitions shared by the base and test graphs.
        raw_init_ids = [0, 1]
        raw_init_vals = np.random.rand(2, dim)
        raw_ids = [
            0,
        ]
        x = constant_op.constant(np.random.rand(dim, len(raw_ids)),
                                 dtype=d_dtype)

        # base graph
        base_var = resource_variable_ops.ResourceVariable(raw_init_vals,
                                                          dtype=d_dtype)
        ids = constant_op.constant(raw_ids, dtype=k_dtype)
        pred0 = math_ops.matmul(embedding_ops.embedding_lookup([base_var], ids),
                                x)
        loss0 = pred0 * pred0
        base_opt_op = base_opt.minimize(loss0)

        # test graph
        embeddings = de.get_variable(
            "t2020-" + name + str(id),
            key_dtype=k_dtype,
            value_dtype=d_dtype,
            devices=_get_devices() * num_shards,
            initializer=1.0,
            dim=dim,
        )
        self.device_check(embeddings)
        init_ids = constant_op.constant(raw_init_ids, dtype=k_dtype)
        init_vals = constant_op.constant(raw_init_vals, dtype=d_dtype)
        init_op = embeddings.upsert(init_ids, init_vals)
        self.evaluate(init_op)

        test_var, trainable = de.embedding_lookup([embeddings],
                                                  ids,
                                                  return_trainable=True)
        pred1 = math_ops.matmul(test_var, x)
        loss1 = pred1 * pred1

        test_opt_op = test_opt.minimize(loss1, var_list=[trainable])

        self.evaluate(variables.global_variables_initializer())

        for _ in range(run_step):
          sess.run(base_opt_op)

        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType(raw_init_vals[raw_ids],
                                           self.evaluate(test_var))
        # Run `run_step` steps of the test optimizer.
        for _ in range(run_step):
          sess.run(test_opt_op)

        table_var = embeddings.lookup(ids)
        # Validate updated params
        self.assertAllCloseAccordingToType(
            self.evaluate(base_var)[raw_ids],
            self.evaluate(table_var),
            msg="Cond:{},{},{},{},{},{}".format(num_shards, k_dtype, d_dtype,
                                                initial_mode, dim, run_step),
        )