Example #1
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = hess_v.eval()
      self.assertAllClose(hess_v_value, hess_v_actual)
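The identity the test relies on, namely that the Hessian of f(x) = x^T A x is A + A^T, is easy to confirm numerically without TensorFlow. The NumPy sketch below is illustrative only (none of it comes from the test); it compares a central-difference Hessian-vector product against (A + A^T) v.

import numpy as np

# Illustration only: numerically confirm that the Hessian-vector product
# of f(x) = x^T A x equals (A + A^T) v.
rng = np.random.RandomState(0)
m = 4
A = rng.randn(m, m)
x = rng.randn(m, 1)
v = rng.randn(m, 1)

def grad_f(z, h=1e-4):
    # Central-difference gradient of f(z) = z^T A z.
    g = np.zeros_like(z)
    for i in range(m):
        e = np.zeros_like(z)
        e[i] = h
        g[i] = (((z + e).T @ A @ (z + e)).item() -
                ((z - e).T @ A @ (z - e)).item()) / (2 * h)
    return g

h = 1e-4
hvp_numeric = (grad_f(x + h * v) - grad_f(x - h * v)) / (2 * h)
hvp_exact = (A + A.T) @ v
print(np.max(np.abs(hvp_numeric - hvp_exact)))  # small, roughly 1e-7 or below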
Example #2
    def compute_hessian(self, objective, argument):
        if not isinstance(argument, list):
            argA = tf.zeros_like(argument)
            tfhess = _hessian_vector_product(objective, [argument], [argA])

            def hess(x, a):
                feed_dict = {argument: x, argA: a}
                return self._session.run(tfhess[0], feed_dict)

        else:
            argA = [tf.zeros_like(arg) for arg in argument]
            tfhess = _hessian_vector_product(objective, argument, argA)

            def hess(x, a):
                feed_dict = {i: d for i, d in zip(argument + argA, x + a)}
                return self._session.run(tfhess, feed_dict)

        return hess
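For orientation, here is a minimal standalone sketch of what such a hess(x, a) closure evaluates. It assumes a TF 1.x graph-mode environment in which the private helper _hessian_vector_product can be imported from tensorflow.python.ops.gradients_impl; the objective and placeholder names are made up for illustration.

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.gradients_impl import _hessian_vector_product

# Hypothetical single-argument case: f(x) = sum(x^2), so the Hessian is 2*I
# and the Hessian-vector product of any direction v is simply 2*v.
x = tf.placeholder(tf.float32, shape=[3], name="x")
v = tf.placeholder(tf.float32, shape=[3], name="v")
objective = tf.reduce_sum(x ** 2)
hvp = _hessian_vector_product(objective, [x], [v])[0]

with tf.Session() as sess:
    out = sess.run(hvp, feed_dict={x: np.array([1., 2., 3.], np.float32),
                                   v: np.array([1., 0., 0.], np.float32)})
    print(out)  # expected: [2., 0., 0.]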
Example #3
    def compute_hessian(self, objective, argument):
        if not isinstance(argument, list):
            argA = tf.Variable(tf.zeros(tf.shape(argument)))
            tfhess = _hessian_vector_product(objective, [argument], [argA])

            def hess(x, a):
                feed_dict = {argument: x, argA: a}
                return self._session.run(tfhess[0], feed_dict)

        else:
            argA = [tf.Variable(tf.zeros(tf.shape(arg)))
                    for arg in argument]
            tfhess = _hessian_vector_product(objective, argument, argA)

            def hess(x, a):
                feed_dict = {i: d for i, d in zip(argument+argA, x+a)}
                return self._session.run(tfhess, feed_dict)

        return hess
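The list branch above handles objectives with several argument tensors by zipping arguments and directions into one feed_dict. A hedged sketch of that pattern, again assuming TF 1.x and the same private import; the two placeholder arguments and the quadratic objective are invented for illustration:

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.gradients_impl import _hessian_vector_product

# Two hypothetical arguments with matching direction placeholders.
args = [tf.placeholder(tf.float32, [2], name="w"),
        tf.placeholder(tf.float32, [3], name="b")]
dirs = [tf.placeholder(tf.float32, [2], name="dw"),
        tf.placeholder(tf.float32, [3], name="db")]
objective = tf.reduce_sum(args[0] ** 2) + tf.reduce_sum(args[1] ** 2)
hvp = _hessian_vector_product(objective, args, dirs)  # one result per argument

x = [np.ones(2, np.float32), np.ones(3, np.float32)]
a = [np.array([1., 0.], np.float32), np.array([0., 1., 0.], np.float32)]
with tf.Session() as sess:
    feed_dict = {i: d for i, d in zip(args + dirs, x + a)}
    print(sess.run(hvp, feed_dict))  # [2*dw, 2*db] since the Hessian is 2*I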
Example #4
    def _construct_laszlo_operator_batched(self, dtype=np.float32):
        L = self.get_my_model().total_loss
        ws = self.get_my_model().trainable_weights
        X = self.get_my_model()._feed_inputs[0]
        y = self.get_my_model()._feed_targets[0]
        sample_weights = self.get_my_model()._feed_sample_weights[0]
        bs = self.batch_size

        shapes = [K.int_shape(w) for w in ws]
        dim = np.sum([np.prod(s) for s in shapes])
        shape = (dim, dim)
        linear_operator = collections.namedtuple(
            "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])

        v_vect = tf.placeholder(tf.float32, [dim, 1])
        v_reshaped = []
        cur = 0
        for s in shapes:
            v_reshaped.append(K.reshape(v_vect[cur:np.prod(s) + cur], s))
            cur += np.prod(s)
        Hv_vect = _hessian_vector_product(L, ws, v_reshaped)

        sess = K.get_session()

        # noinspection SpellCheckingInspection
        def apply_cpu(v):
            res = [0 for _ in ws]
            for id1 in range(self.X.shape[0] // bs):
                x_batch = self.X[id1 * bs:(id1 + 1) * bs].astype(dtype)
                y_batch = self.y[id1 * bs:(id1 + 1) * bs].astype(dtype)
                ress = sess.run(Hv_vect, feed_dict={v_vect: v,
                                                    X: x_batch,
                                                    y: y_batch,
                                                    sample_weights: np.ones(shape=(bs,))
                                                    })
                for id2 in range(len(ws)):
                    res[id2] += bs * ress[id2].reshape(-1, 1)

            return np.concatenate(res, axis=0) / self.X.shape[0]

        def apply(v):
            return tf.py_func(apply_cpu, [v], tf.float32)

        return linear_operator(
            apply=apply,
            apply_adjoint=apply,
            dtype=dtype,
            shape=shape)
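The slicing loop in the operator construction is plain bookkeeping: a flat vector of length dim is cut into consecutive chunks and reshaped to the individual weight shapes before being passed to _hessian_vector_product. Below is a NumPy-only sketch of that mapping, with made-up shapes standing in for K.int_shape(w).

import numpy as np

# Made-up weight shapes; dim is the total number of parameters.
shapes = [(2, 3), (3,), (3, 4)]
dim = int(np.sum([np.prod(s) for s in shapes]))

v = np.arange(dim, dtype=np.float32).reshape(dim, 1)  # flat vector, like v_vect

v_reshaped, cur = [], 0
for s in shapes:
    size = int(np.prod(s))
    v_reshaped.append(v[cur:cur + size].reshape(s))
    cur += size

assert cur == dim
assert [chunk.shape for chunk in v_reshaped] == shapes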
Example #5
def _construct_linear_operator_batched_scipy(L,
                                             ws,
                                             data,
                                             bs,
                                             X,
                                             y,
                                             sample_weights,
                                             n_classes,
                                             dtype=np.float32):
    # A bit of painful configuration
    shapes = [K.int_shape(w) for w in ws]
    dim = np.sum([np.prod(s) for s in shapes])
    shape = (dim, dim)
    v_vect = tf.placeholder(tf.float32, [dim])
    v_reshaped = []
    cur = 0

    # TODO: Learning phase?
    logger.info("Calculating shapes")
    for s in shapes:
        v_reshaped.append(K.reshape(v_vect[cur:np.prod(s) + cur], s))
        cur += np.prod(s)
    logger.info("Consturcting hvp op")
    vector_product = _hessian_vector_product(L, ws, v_reshaped)
    logger.info("Done constructing hvp op")

    # Apply
    def apply_cpu(v):
        sess = K.get_session()
        if isinstance(data, list):
            res = [np.zeros(K.int_shape(vv), dtype=np.float32) for vv in ws]
            nb = int((bs - 1 + data[0].shape[0]) / bs)
            n = 0
            for id in tqdm.tqdm(range(nb), total=nb):
                x_batch = data[0][id * bs:(id + 1) * bs].astype(dtype)
                y_batch = data[1][id * bs:(id + 1) * bs].astype(dtype)
                n += len(x_batch)
                if y_batch.shape[-1] != n_classes:
                    y_batch = np_utils.to_categorical(y_batch, n_classes)
                fd = {
                    v_vect: v,
                    X: x_batch,
                    y: y_batch,
                    K.learning_phase(): 1.0,
                    sample_weights: np.ones(shape=(len(x_batch), )),
                }
                ress = sess.run(vector_product, feed_dict=fd)
                for id2 in range(len(ws)):
                    # res[id2] += bs * ress[id2].reshape(-1, 1)
                    g = ress[id2]
                    # NOTE: Could be optimized
                    if isinstance(vector_product[id2], tf.IndexedSlices):
                        vals, indices = g.values, g.indices
                        res[id2][indices] += len(x_batch) * vals
                    else:
                        res[id2] += len(x_batch) * g

            return np.concatenate([g.reshape(-1, 1) for g in res], axis=0) / n
            # return np.concatenate(res, axis=0) / data[0].shape[0]
        else:
            raise NotImplementedError()

    A = linalg.LinearOperator(shape, matvec=apply_cpu)

    return A, apply_cpu, dim
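Because the batched HVP is exposed as a matrix-free scipy.sparse.linalg.LinearOperator, SciPy's iterative solvers can work with it without ever forming the full Hessian. The self-contained sketch below uses a tiny explicit matrix in place of apply_cpu to show the kind of call this enables, e.g. estimating the leading Hessian eigenvalue with Lanczos.

import numpy as np
from scipy.sparse import linalg

# A 3x3 explicit matrix stands in for the batched Hessian-vector product.
H = np.diag([1.0, 2.0, 3.0])
A = linalg.LinearOperator(H.shape, matvec=lambda v: H @ v, dtype=np.float64)

# Leading eigenpair via Lanczos, using only matrix-vector products.
eigval, eigvec = linalg.eigsh(A, k=1, which="LA")
print(eigval)  # approximately [3.]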