def commonly_apply_update_verify_v2(self):
    """Check TimestampRestrictPolicy.apply_update bookkeeping in eager mode.

    Two overlapping batches of int64 ids are passed to ``apply_update``,
    one second apart. The test asserts two things:

    - the policy's ``status`` table grows to the size of the union of ids
      seen so far;
    - every id that appeared only in the first batch has a strictly smaller
      recorded timestamp than every id that appeared in the second batch.

    Skipped unless TensorFlow is executing eagerly.
    """
    if not context.executing_eagerly():
        self.skipTest('Skip graph mode test.')

    # Batch 1 touches ids 0..5; batch 2 touches ids 3..8 (overlap on 3..5).
    first_inputs = np.array(range(6), dtype=np.int64)
    second_inputs = np.array(range(3, 9), dtype=np.int64)
    # Ids 0..2 are seen only in batch 1; ids 3..8 are (re)stamped by batch 2.
    overdue_features = np.array(range(3), dtype=np.int64)
    updated_features = np.array(range(3, 9), dtype=np.int64)
    # Union of both batches: ids 0..8.
    all_features = np.array(range(9), dtype=np.int64)

    with self.session(config=default_config):
        var = de.get_variable(
            'sp_var',
            key_dtype=dtypes.int64,
            value_dtype=dtypes.float32,
            initializer=-0.1,
            dim=2,
        )
        # The returned embedding and trainable are not inspected.
        # NOTE(review): the lookup is kept in case it is required for the
        # policy setup; confirm whether it can be dropped.
        _embedding, _trainable = de.embedding_lookup(
            var, first_inputs, return_trainable=True, name='re0212')
        policy = de.TimestampRestrictPolicy(var)

        # No ids have been recorded yet.
        self.assertAllEqual(policy.status.size(), 0)

        policy.apply_update(first_inputs)
        self.assertAllEqual(policy.status.size(), len(first_inputs))

        # The pause lets batch 2 receive a later timestamp than batch 1.
        # Presumably timestamps have coarse (e.g. one-second) resolution;
        # TODO confirm.
        time.sleep(1)
        policy.apply_update(second_inputs)
        self.assertAllEqual(policy.status.size(), len(all_features))

        exported_keys, exported_tstp = policy.status.export()
        by_key = dict(zip(exported_keys.numpy(), exported_tstp.numpy()))
        # Order timestamps by id. The exported ids are expected to be exactly
        # 0..8, so array position equals id.
        timestamps = np.array([by_key[k] for k in sorted(by_key)])

        for early in timestamps[overdue_features]:
            for late in timestamps[updated_features]:
                self.assertLess(early, late)
def commonly_apply_update_verify(self):
    """Check TimestampRestrictPolicy.apply_update bookkeeping in graph mode.

    The update op is built once on an int64 placeholder. It is then run in a
    ``session.Session`` for two overlapping batches fed one second apart.
    The test asserts two things:

    - the ``status`` table size after each step;
    - the id seen only in the first batch has a strictly smaller timestamp
      than every id seen in the second batch.

    NOTE(review): this relies on ``array_ops.placeholder`` and
    ``sess.run``, so it presumably must be invoked in graph mode. There is
    no eager-mode guard here.
    """
    # Batch 1 touches ids 0..2; batch 2 touches ids 1..3 (overlap on 1..2).
    first_inputs = np.array(range(3), dtype=np.int64)
    second_inputs = np.array(range(1, 4), dtype=np.int64)
    # Id 0 is seen only in batch 1; ids 1..3 are (re)stamped by batch 2.
    overdue_features = np.array([0], dtype=np.int64)
    updated_features = np.array(range(1, 4), dtype=np.int64)
    # Union of both batches: ids 0..3 (expected table size 4).
    all_features = np.array(range(4), dtype=np.int64)

    with session.Session(config=default_config) as sess:
        ids = array_ops.placeholder(dtypes.int64)
        var = de.get_variable(
            'sp_var',
            key_dtype=ids.dtype,
            value_dtype=dtypes.float32,
            initializer=-0.1,
            init_size=256,
            dim=2,
        )
        # The returned embedding and trainable are not inspected.
        # NOTE(review): the lookup is kept in case it is required for the
        # policy setup; confirm whether it can be dropped.
        _embedding, _trainable = de.embedding_lookup(
            var, ids, return_trainable=True, name='wf7843')
        policy = de.TimestampRestrictPolicy(var)
        # Built once; each sess.run below feeds a different batch into `ids`.
        update_op = policy.apply_update(ids)

        # No ids have been recorded yet.
        self.assertAllEqual(sess.run(policy.status.size()), 0)

        sess.run(update_op, feed_dict={ids: first_inputs})
        self.assertAllEqual(sess.run(policy.status.size()), len(first_inputs))

        # The pause lets batch 2 receive a later timestamp than batch 1.
        # Presumably timestamps have coarse (e.g. one-second) resolution;
        # TODO confirm.
        time.sleep(1)
        sess.run(update_op, feed_dict={ids: second_inputs})
        self.assertAllEqual(sess.run(policy.status.size()), len(all_features))

        exported_keys, exported_tstp = sess.run(policy.status.export())
        by_key = dict(zip(exported_keys, exported_tstp))
        # Order timestamps by id. The exported ids are expected to be exactly
        # 0..3, so array position equals id.
        timestamps = np.array([by_key[k] for k in sorted(by_key)])

        for early in timestamps[overdue_features]:
            for late in timestamps[updated_features]:
                self.assertLess(early, late)