コード例 #1
0
    def commonly_apply_update_verify_v2(self):
        """Verify TimestampRestrictPolicy.apply_update bookkeeping eagerly.

        Two overlapping batches of int64 ids are pushed through
        ``policy.apply_update`` roughly one second apart. The helper then
        checks two things:

        * ``policy.status`` grows to one entry per distinct id seen so far.
        * Every id touched only by the first batch has a strictly smaller
          exported timestamp than every id touched by the second batch.

        The check is skipped unless ``context.executing_eagerly()`` is true.
        """
        if not context.executing_eagerly():
            self.skipTest('Skip graph mode test.')

        batch_a = np.array(range(6), dtype=np.int64)  # ids 0..5
        batch_b = np.array(range(3, 9), dtype=np.int64)  # ids 3..8
        # Ids seen only in batch_a vs. ids (re)written by batch_b.
        stale_ids = np.array(range(3), dtype=np.int64)
        fresh_ids = np.array(range(3, 9), dtype=np.int64)
        every_id = np.array(range(9), dtype=np.int64)

        with self.session(config=default_config):
            var = de.get_variable('sp_var',
                                  key_dtype=dtypes.int64,
                                  value_dtype=dtypes.float32,
                                  initializer=-0.1,
                                  dim=2)
            # The lookup outputs are not used below. The call is kept
            # presumably so the variable has a trainable wrapper before the
            # policy is built -- TODO confirm.
            _embed_w, _trainable = de.embedding_lookup(var,
                                                       batch_a,
                                                       return_trainable=True,
                                                       name='re0212')
            policy = de.TimestampRestrictPolicy(var)

            # Status starts empty and gains one entry per new id.
            self.assertAllEqual(policy.status.size(), 0)
            policy.apply_update(batch_a)
            self.assertAllEqual(policy.status.size(), len(batch_a))
            # Sleep so that the second batch is stamped strictly later.
            time.sleep(1)
            policy.apply_update(batch_b)
            self.assertAllEqual(policy.status.size(), len(every_id))

            exported_keys, exported_stamps = policy.status.export()
            stamp_by_key = dict(zip(exported_keys.numpy(),
                                    exported_stamps.numpy()))
            # Order stamps by key. Because the ids here are exactly 0..8,
            # position i then holds the stamp of id i.
            ordered_stamps = np.array(
                [stamp for _, stamp in sorted(stamp_by_key.items())])

            for older in ordered_stamps[stale_ids]:
                for newer in ordered_stamps[fresh_ids]:
                    self.assertLess(older, newer)
コード例 #2
0
    def commonly_apply_update_verify(self):
        """Verify TimestampRestrictPolicy.apply_update bookkeeping in a Session.

        A single ``apply_update`` op is built on an int64 placeholder. It is
        run with two overlapping id batches, roughly one second apart, inside
        a ``session.Session``. The helper then checks two things:

        * ``policy.status`` holds one entry per distinct id seen so far.
        * The id fed only in the first batch has a strictly smaller exported
          timestamp than every id fed in the second batch.
        """
        batch_a = np.array(range(3), dtype=np.int64)  # ids 0..2
        batch_b = np.array(range(1, 4), dtype=np.int64)  # ids 1..3
        # Id 0 appears only in batch_a; ids 1..3 are (re)written by batch_b.
        stale_ids = np.array([0], dtype=np.int64)
        fresh_ids = np.array(range(1, 4), dtype=np.int64)
        # Distinct ids after both batches (== 4 for the inputs above).
        distinct_total = len(np.union1d(batch_a, batch_b))

        with session.Session(config=default_config) as sess:
            ids = array_ops.placeholder(dtypes.int64)
            var = de.get_variable('sp_var',
                                  key_dtype=ids.dtype,
                                  value_dtype=dtypes.float32,
                                  initializer=-0.1,
                                  init_size=256,
                                  dim=2)
            # The lookup outputs are not used below. The call is kept
            # presumably so the variable has a trainable wrapper before the
            # policy is built -- TODO confirm.
            _embed_w, _trainable = de.embedding_lookup(var,
                                                       ids,
                                                       return_trainable=True,
                                                       name='wf7843')
            policy = de.TimestampRestrictPolicy(var)
            update_op = policy.apply_update(ids)

            self.assertAllEqual(sess.run(policy.status.size()), 0)
            sess.run(update_op, feed_dict={ids: batch_a})
            self.assertAllEqual(sess.run(policy.status.size()), len(batch_a))
            # Sleep so that the second batch is stamped strictly later.
            time.sleep(1)
            sess.run(update_op, feed_dict={ids: batch_b})
            self.assertAllEqual(sess.run(policy.status.size()), distinct_total)

            exported_keys, exported_stamps = sess.run(policy.status.export())
            stamp_by_key = dict(zip(exported_keys, exported_stamps))
            # Order stamps by key. Because the ids here are exactly 0..3,
            # position i then holds the stamp of id i.
            ordered_stamps = np.array(
                [stamp for _, stamp in sorted(stamp_by_key.items())])

            for older in ordered_stamps[stale_ids]:
                for newer in ordered_stamps[fresh_ids]:
                    self.assertLess(older, newer)