Code example #1
0
 def test_GraphKeys(self):
   """Every de.Variable must register in DYNAMIC_EMBEDDING_VARIABLES,
   and only trainable ones in TRAINABLE_DYNAMIC_EMBEDDING_VARIABLES.
   """
   # Trainable by default.
   v0 = de.Variable(key_dtype=dtypes.int64,
                    value_dtype=dtypes.float32,
                    initializer=0.0,
                    name="v0")
   # Explicitly non-trainable.
   v1 = de.Variable(key_dtype=dtypes.int64,
                    value_dtype=dtypes.float32,
                    initializer=0.0,
                    name="v1",
                    trainable=False)
   v2 = de.get_variable(
       "v2",
       key_dtype=dtypes.int64,
       value_dtype=dtypes.float32,
       initializer=init_ops.zeros_initializer,
       dim=10,
   )
   v3 = de.get_variable("v3",
                        key_dtype=dtypes.int64,
                        value_dtype=dtypes.float32,
                        initializer=init_ops.zeros_initializer,
                        dim=10,
                        trainable=False)
   de_vars = ops.get_collection(de.GraphKeys.DYNAMIC_EMBEDDING_VARIABLES)
   self.assertSetEqual(set([v0, v1, v2, v3]), set(de_vars))
   de_trainable_vars = ops.get_collection(
       de.GraphKeys.TRAINABLE_DYNAMIC_EMBEDDING_VARIABLES)
   # Fix: compare sets with assertSetEqual (order-independent), consistent
   # with the assertion above; assertAllEqual on sets depends on unspecified
   # iteration order / ndarray coercion of set objects.
   self.assertSetEqual(set([v0, v2]), set(de_trainable_vars))
Code example #2
0
def create_slots(primary, init, slot_name, op_name):
    """Create (or fetch) the slot de.Variable for a stateful optimizer.

    Args:
      primary: trainable wrapper exposing the primary `params` variable and
        the looked-up `ids`.
      init: initializer for the slot values.
      slot_name: short name of the slot (e.g. "m", "v").
      op_name: name of the optimizer op that owns the slot.

    Returns:
      A trainable wrapper produced by `de.embedding_lookup` over the slot
      variable for `primary.ids`.
    """
    primary_var = primary.params
    lookup_ids = primary.ids

    store = variable_scope._get_default_variable_store()
    slot_full_name = "/".join((primary_var.name, op_name, slot_name))
    if slot_full_name not in store._vars:
        # Place the slot alongside the primary variable.
        with ops.colocate_with(primary, ignore_existing=True):
            store._vars[slot_full_name] = de.Variable(
                name=slot_full_name,
                key_dtype=primary_var.key_dtype,
                value_dtype=primary_var.value_dtype,
                dim=primary_var.dim,
                devices=primary_var.devices,
                partitioner=primary_var.partition_fn,
                initializer=init,
                init_size=primary_var.init_size,
                trainable=False,
                checkpoint=primary_var.checkpoint,
            )

    _, slot_wrapper = de.embedding_lookup(
        params=store._vars[slot_full_name],
        ids=lookup_ids,
        name=slot_name,
        return_trainable=True,
    )
    return slot_wrapper
Code example #3
0
def create_slots(primary, init, slot_name, op_name, bp_v2):
    """Create (or fetch) the slot de.Variable for a stateful optimizer.

    Args:
      primary: trainable wrapper (or ShadowVariable) exposing the primary
        `params` variable and the looked-up `ids`.
      init: initializer for the slot values.
      slot_name: short name of the slot (e.g. "m", "v").
      op_name: name of the optimizer op that owns the slot.
      bp_v2: override for the slot's bp_v2 flag; falls back to the primary
        variable's setting when None.

    Returns:
      A trainable wrapper (ShadowVariable or embedding-lookup wrapper) over
      the slot variable for `primary.ids`.
    """
    primary_var = primary.params
    lookup_ids = primary.ids

    store = variable_scope._get_default_variable_store()
    slot_full_name = "/".join((primary_var.name, op_name, slot_name))
    if slot_full_name not in store._vars:
        # Place the slot alongside the primary variable.
        with ops.colocate_with(primary, ignore_existing=True):
            new_slot = de.Variable(
                name=slot_full_name,
                key_dtype=primary_var.key_dtype,
                value_dtype=primary_var.value_dtype,
                dim=primary_var.dim,
                devices=primary_var.devices,
                partitioner=primary_var.partition_fn,
                initializer=init,
                kv_creator=primary_var.kv_creator,
                trainable=False,
                checkpoint=primary_var.checkpoint,
                bp_v2=primary_var.bp_v2 if bp_v2 is None else bp_v2,
            )

        store._vars[slot_full_name] = new_slot
        # Record the optimizer variable on the primary for later tracing.
        primary._optimizer_vars.append(new_slot)

    if context.executing_eagerly():
        wrapper_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
    else:
        # Graph mode historically named slot trainable-wrappers with the bare
        # slot_name; keep that for forward compatibility.
        wrapper_name = slot_name

    if isinstance(primary, de.shadow_ops.ShadowVariable):
        return de.shadow_ops.ShadowVariable(
            params=store._vars[slot_full_name],
            ids=primary.ids,
            exists=primary.exists,
            name=slot_full_name,
            trainable=False,
        )

    _, slot_wrapper = de.embedding_lookup(
        params=store._vars[slot_full_name],
        ids=lookup_ids,
        name=wrapper_name,
        return_trainable=True,
    )
    return slot_wrapper
Code example #4
0
    def test_saving_restoring_checkpoint(self):
        """A de.Variable's contents must survive the automatic
        checkpoint/restore cycle of MonitoredTrainingSession.
        """

        logdir = _test_dir(self.get_temp_dir(),
                           "test_saving_restoring_checkpoint")
        with ops.Graph().as_default():
            gstep = training_util.create_global_step()
            do_step = state_ops.assign_add(gstep, 1)

            # Ordinary dense variables saved alongside the hash table.
            v0 = variables.Variable(10.0, name="v0")
            v1 = variables.Variable(20.0, name="v1")

            target_values = [[0.0], [1.0], [2.0]]
            keys = array_ops.placeholder(dtypes.int64)
            values = constant_op.constant(target_values, dtypes.float32)

            # Dynamic-embedding hash table under test; lookup of a missing
            # key yields the default initializer value (-1.0).
            table = de.Variable(
                key_dtype=dtypes.int64,
                value_dtype=dtypes.float32,
                initializer=-1.0,
                name="m100",
                dim=1,
            )
            upsert_op = table.upsert(keys, values)
            lookup_op = table.lookup(keys)
            size_op = table.size()
            # First session: advance the global step, fill the table, and let
            # the chief write a checkpoint to logdir on exit.
            with monitored_session.MonitoredTrainingSession(
                    config=default_config, is_chief=True,
                    checkpoint_dir=logdir) as sess:
                self.assertEqual(0, sess.run(gstep))
                self.assertEqual(1, sess.run(do_step))
                self.assertEqual(2, sess.run(do_step))

                # Check that the parameter nodes have been initialized.
                self.assertEqual(10.0, sess.run(v0))
                self.assertEqual(20.0, sess.run(v1))
                self.assertAllEqual(0, sess.run(size_op))
                sess.run(upsert_op, feed_dict={keys: [0, 1, 2]})
                self.assertAllEqual(3, sess.run(size_op))
                self.device_check(table)

            # A restart will find the checkpoint and recover automatically.
            with monitored_session.MonitoredTrainingSession(
                    config=default_config, is_chief=True,
                    checkpoint_dir=logdir) as sess:
                # Global step and table contents come back from the checkpoint.
                self.assertEqual(2, sess.run(gstep))
                self.assertAllEqual(3, sess.run(table.size()))
                self.assertAllEqual(
                    target_values,
                    sess.run(lookup_op, feed_dict={keys: [0, 1, 2]}))

                self.device_check(table)
def create_slots(primary, init, slot_name, op_name):
  """Helper function for creating a slot variable for statefull optimizers."""
  # NOTE(lwk): `primary` is a de.TrainableWrapper, so `params_var_` is the
  # underlying de.Variable and `params_ids_` are the feature ids.
  params_var_, params_ids_ = primary.params, primary.ids

  scope_store = variable_scope._get_default_variable_store()
  full_name = params_var_.name + "/" + op_name + "/" + slot_name
  if full_name not in scope_store._vars:
    with ops.colocate_with(primary, ignore_existing=True):
      # NOTE(lwk): the slot of a de.Variable is itself a de.Variable,
      # i.e. another hash table.
      slot_variable_ = de.Variable(
          name=full_name,
          key_dtype=params_var_.key_dtype,
          value_dtype=params_var_.value_dtype,
          dim=params_var_.dim, # same embedding dimension as the primary
          devices=params_var_.devices,
          partitioner=params_var_.partition_fn,
          initializer=init,
          trainable=False,
          checkpoint=params_var_.checkpoint,
      )

    scope_store._vars[full_name] = slot_variable_

  # NOTE(lwk): the dynamic hash-table mechanism wraps the lookup result in a
  # variable; that wrapper carries the hash table and the ids, and the
  # optimizer uses the same information to build a matching slot variable.
  slot_trainable = None
  _, slot_trainable = de.embedding_lookup(
      params=scope_store._vars[full_name],
      ids=params_ids_,
      name=slot_name,
      return_trainable=True,
  )

  return slot_trainable
    def test_save_restore_only_table(self):
        """Saving only the de.Variable (no dense variables) and restoring its
        underlying table into a freshly built one must recover all entries.
        """
        save_dir = os.path.join(self.get_temp_dir(), "save_restore")
        save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

        with self.session(
                config=default_config,
                graph=ops.Graph(),
                use_gpu=test_util.is_gpu_available(),
        ) as sess:
            # Dense variables exist in the graph but are NOT passed to the
            # Saver below — only the table is checkpointed.
            v0 = variables.Variable(10.0, name="v0")
            v1 = variables.Variable(20.0, name="v1")

            default_val = -1
            keys = constant_op.constant([0, 1, 2], dtypes.int64)
            values = constant_op.constant([[0], [1], [2]], dtypes.int32)
            table = de.Variable(
                dtypes.int64,
                dtypes.int32,
                name="t1",
                initializer=default_val,
                checkpoint=True,
            )

            save = saver.Saver([table])
            self.evaluate(variables.global_variables_initializer())

            # Check that the parameter nodes have been initialized.
            self.assertEqual(10.0, self.evaluate(v0))
            self.assertEqual(20.0, self.evaluate(v1))

            self.assertAllEqual(0, self.evaluate(table.size()))
            self.evaluate(table.upsert(keys, values))
            self.assertAllEqual(3, self.evaluate(table.size()))

            # Saver.save returns the checkpoint path it wrote.
            val = save.save(sess, save_path)
            self.assertIsInstance(val, six.string_types)
            self.assertEqual(save_path, val)
            del table

        with self.session(
                config=default_config,
                graph=ops.Graph(),
                use_gpu=test_util.is_gpu_available(),
        ) as sess:
            default_val = -1
            table = de.Variable(
                dtypes.int64,
                dtypes.int32,
                name="t1",
                initializer=default_val,
                checkpoint=True,
            )
            # Pre-populate with conflicting values; restore must replace them.
            self.evaluate(
                table.upsert(
                    constant_op.constant([0, 2], dtypes.int64),
                    constant_op.constant([[12], [24]], dtypes.int32),
                ))
            self.assertAllEqual(2, self.evaluate(table.size()))

            # Restore via the first shard of the partitioned table directly.
            # NOTE(review): assumes a single-shard table here — _tables[0]
            # covers all keys; verify if the default partition count changes.
            save = saver.Saver([table._tables[0]])

            # Restore the saved values in the parameter nodes.
            save.restore(sess, save_path)
            # Check that the parameter nodes have been restored.

            self.assertAllEqual(3, self.evaluate(table.size()))

            # Missing keys (3, 4) fall back to the default initializer (-1).
            remove_keys = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
            output = table.lookup(remove_keys)
            self.assertAllEqual([[0], [1], [2], [-1], [-1]],
                                self.evaluate(output))
            del table
    def test_save_restore(self):
        """Dense variables and a de.Variable saved together must all be
        recovered by Saver.restore, overwriting any pre-restore contents.
        """
        save_dir = os.path.join(self.get_temp_dir(), "save_restore")
        save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

        with self.session(config=default_config, graph=ops.Graph()) as sess:
            v0 = variables.Variable(10.0, name="v0")
            v1 = variables.Variable(20.0, name="v1")

            keys = constant_op.constant([0, 1, 2], dtypes.int64)
            values = constant_op.constant([[0.0], [1.0], [2.0]],
                                          dtypes.float32)
            table = de.Variable(
                key_dtype=dtypes.int64,
                value_dtype=dtypes.float32,
                initializer=-1.0,
                name="t1",
                dim=1,
            )

            # Dense variables and the hash table share one checkpoint.
            save = saver.Saver(var_list=[v0, v1, table])
            self.evaluate(variables.global_variables_initializer())

            # Check that the parameter nodes have been initialized.
            self.assertEqual(10.0, self.evaluate(v0))
            self.assertEqual(20.0, self.evaluate(v1))

            self.assertAllEqual(0, self.evaluate(table.size()))
            self.evaluate(table.upsert(keys, values))
            self.assertAllEqual(3, self.evaluate(table.size()))

            # Saver.save returns the checkpoint path it wrote.
            val = save.save(sess, save_path)
            self.assertIsInstance(val, six.string_types)
            self.assertEqual(save_path, val)

            del table

        with self.session(config=default_config, graph=ops.Graph()) as sess:
            # Fresh graph with deliberately wrong values; restore must fix them.
            v0 = variables.Variable(-1.0, name="v0")
            v1 = variables.Variable(-1.0, name="v1")
            table = de.Variable(
                name="t1",
                key_dtype=dtypes.int64,
                value_dtype=dtypes.float32,
                initializer=-1.0,
                dim=1,
                checkpoint=True,
            )
            # Pre-populate with conflicting entries; restore must replace them.
            self.evaluate(
                table.upsert(
                    constant_op.constant([0, 1], dtypes.int64),
                    constant_op.constant([[12.0], [24.0]], dtypes.float32),
                ))
            size_op = table.size()
            self.assertAllEqual(2, self.evaluate(size_op))

            save = saver.Saver(var_list=[v0, v1, table])

            # Restore the saved values in the parameter nodes.
            save.restore(sess, save_path)
            # Check that the parameter nodes have been restored.
            self.assertEqual([10.0], self.evaluate(v0))
            self.assertEqual([20.0], self.evaluate(v1))

            self.assertAllEqual(3, self.evaluate(table.size()))

            # Missing keys (5, 6) fall back to the default initializer (-1.0).
            remove_keys = constant_op.constant([5, 0, 1, 2, 6], dtypes.int64)
            output = table.lookup(remove_keys)
            self.assertAllEqual([[-1.0], [0.0], [1.0], [2.0], [-1.0]],
                                self.evaluate(output))

            del table