Exemplo n.º 1
0
    def query_lse(self, context):
        '''
        Pick the next active query for a given context via level set
        estimation, following the straddle heuristic of
        B. Bryan, R. C. Nichol, C. R. Genovese, J. Schneider, C. J. Miller, and L. Wasserman,
        "Active learning for identifying function threshold boundaries," in NIPS, 2006.

        The acquisition minimised here is ``|mu| - 1.96 * sigma``, where
        ``mu`` and ``sigma**2`` are the model's predictive mean and variance,
        i.e. the negated straddle score with the level-set threshold at 0.

        Args:
            context: the context to query in; it is converted with
                ``helper.tuple_context_to_total_context`` before being
                appended to every candidate parameter vector.

        Returns:
            numpy array: the chosen parameters followed by the total context.
        '''
        param_idx = self.func.param_idx
        # Start from the parameter of the closest context that has a positive
        # observation; the matched context itself is not needed here.
        start, _ = helper.find_closest_positive_context_param(
            context, self.xx, self.yy, param_idx, self.func.context_idx)
        full_context = helper.tuple_context_to_total_context(context)

        def straddle_objective(params):
            # Accept one parameter vector or a batch of them (one per row).
            batch = params[None, :] if params.ndim == 1 else params
            repeated = np.tile(full_context, (batch.shape[0], 1))
            mean, variance = self.model.predict(np.hstack((batch, repeated)))
            return np.abs(mean) - 1.96 * np.sqrt(variance)

        bounds = self.func.x_range[:, param_idx]
        best_params, _ = helper.global_minimize(
            straddle_objective, None, bounds, 10000, start)
        return np.hstack((best_params, full_context))
Exemplo n.º 2
0
    def query_best_prob(self, context):
        '''
        Returns the input that has the highest probability to be in the super
        level set for a given context.

        The objective minimised is ``-mu / sigma``, where ``mu`` and
        ``sigma**2`` are the model's predictive mean and variance at
        ``[x, context]``. The negated optimum is stored as
        ``self.best_beta``. It is then used to derive
        ``self.beta = norm.ppf(self.betalambda * norm.cdf(self.best_beta))``.

        Args:
            context: context appended (unchanged) to every candidate
                parameter vector.

        Returns:
            numpy array: the best parameters followed by ``context``.

        Raises:
            ValueError: if the derived ``self.beta`` exceeds ``self.best_beta``.
        '''
        param_idx = self.func.param_idx
        x0, _ = helper.find_closest_positive_context_param(
            context, self.xx, self.yy, param_idx, self.func.context_idx)

        def with_context(x):
            # Accept a single parameter vector or a batch (one per row) and
            # append the context to every row.
            if x.ndim == 1:
                x = x[None, :]
            return np.hstack((x, np.tile(context, (x.shape[0], 1))))

        def ac_f(x):
            mu, var = self.model.predict(with_context(x))
            return (-mu) / np.sqrt(var)

        def ac_fg(x):
            x = with_context(x)
            mu, var = self.model.predict(x)
            dmdx, dvdx = self.model.predictive_gradients(x)
            dmdx = dmdx[0, :, 0]
            dvdx = dvdx[0, :]
            sigma = np.sqrt(var)
            f = (-mu) / sigma
            # Quotient rule for d/dx [(-mu) / sigma], using
            # dsigma/dx = dvdx / (2 * sigma) and sigma**2 = var.
            g = (-sigma * dmdx - 0.5 * (-mu) * dvdx / sigma) / var
            # The gradient is w.r.t. the full input; keep only parameter dims.
            return f[0, 0], g[0, param_idx]

        # Besides the closest-context guess, also pass every observed parameter
        # row whose label is > 0 (presumably used as extra starting points).
        positives = self.xx[np.squeeze(self.yy) > 0][:, param_idx]
        x0 = np.vstack((x0, positives))
        x_star, y_star = helper.global_minimize(
            ac_f, ac_fg, self.func.x_range[:, param_idx], 10000, x0)
        # Single-argument print produces the same output on Python 2 and 3.
        print('best beta= {}'.format(-y_star))
        self.best_beta = -y_star
        self.beta = norm.ppf(self.betalambda * norm.cdf(self.best_beta))
        if self.best_beta < 0:
            try:
                ask = raw_input  # Python 2 (plain `input` would eval here)
            except NameError:
                ask = input  # Python 3
            # Implicit concatenation avoids embedding the source indentation
            # in the prompt (the old backslash continuation did).
            ask('Warning! Cannot find any parameter to be super level set '
                'with more than 0.5 probability. Are you sure to continue?')
        if self.beta > self.best_beta:
            raise ValueError('Beta cannot be larger than best beta.')
        return np.hstack((x_star, context))
Exemplo n.º 3
0
    def query(self, context):
        '''
        Returns the next query for ``context`` by minimising the negated
        output of the model over the parameter dimensions (i.e. maximising
        the model's prediction).

        The optimiser receives guesses from ``helper.grid_around_point``.
        These are built around the parameter of the closest context that has
        a positive observation. It also receives the input gradient of the
        model's first output.

        Args:
            context: context vector appended to every candidate parameter.

        Returns:
            numpy array: the chosen parameters followed by ``context``.
        '''
        param_idx = self.func.param_idx
        start, _ = helper.find_closest_positive_context_param(
            context, self.xx, self.yy, param_idx, self.func.context_idx)

        # Backend function returning the gradient of the first model output
        # w.r.t. the model inputs.
        # NOTE(review): these ops are rebuilt on every call; with a graph-mode
        # backend that may grow the graph across repeated queries — confirm.
        model_inputs = self.model.inputs
        input_grads = kb.function(
            model_inputs, kb.gradients(self.model.outputs[0], model_inputs))

        def neg_prediction(params):
            # Batched objective: one parameter vector per row.
            contexts = np.tile(context, (params.shape[0], 1))
            batch = np.hstack((params, contexts))
            return -(self.model.predict(batch).astype(np.float64))

        def neg_prediction_and_grad(params):
            # Single-point objective and its gradient w.r.t. parameter dims.
            single = np.array([np.hstack((params, context))])
            value = self.model.predict(single)[0].astype(np.float64)
            grad = input_grads([single])[0][0, param_idx].astype(np.float64)
            return -value, -grad

        x_range = self.func.x_range
        half_span = 0.5 * (x_range[1] - x_range[0])
        guesses = helper.grid_around_point(start, half_span, 5, x_range)
        x_star, y_star = helper.global_minimize(
            neg_prediction, neg_prediction_and_grad, x_range[:, param_idx],
            10000, guesses)
        print('x_star={}, y_star={}'.format(x_star, y_star))
        return np.hstack((x_star, context))