Example #1
        def _eval(v):
            """The evaluation function.

            Args:
                v (numpy.ndarray): The vector to be multiplied with the Hessian.

            Returns:
                numpy.ndarray: The product of the Hessian of the function f and v.

            """
            xs = tuple(self._target.flat_to_params(v))
            ret = sliced_fun(self._hvp_fun['f_hx_plain'], self._num_slices)(
                inputs, xs) + self._reg_coeff * v
            return ret
Example #2
    def constraint_val(self, inputs, extra_inputs=None):
        """Constraint value.

        Args:
            inputs (list[numpy.ndarray]): A list of inputs, which could be
                subsampled if needed. It is assumed that the first dimension
                of these inputs corresponds to the number of data points.
            extra_inputs (list[numpy.ndarray]): A list of extra inputs which
                should not be subsampled.

        Returns:
            float: Constraint value.

        """
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        return sliced_fun(self._opt_fun['f_constraint'],
                          self._num_slices)(inputs, extra_inputs)
Example #3
    def constraint_val(self, inputs, extra_inputs=None):
        """Calculate the constraint value.

        Parameters
        ----------
        inputs :
            A list of symbolic variables as inputs, which could be subsampled if needed. It is
            assumed that the first dimension of these inputs corresponds to the number of data points.
        extra_inputs : optional
            A list of symbolic variables as extra inputs which should not be subsampled.

        Returns
        -------
        constraint_value : float
            The value of the constraint.
        """
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        return sliced_fun(self._opt_fun["f_constraint"],
                          self._num_slices)(inputs, extra_inputs)
Example #4
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None,
                 name=None):
        """Optimize the function.

        Args:
            inputs (list[numpy.ndarray]): A list of inputs, which could be
                subsampled if needed. It is assumed that the first dimension
                of these inputs corresponds to the number of data points.
            extra_inputs (list[numpy.ndarray]): A list of extra inputs which
                should not be subsampled.
            subsample_grouped_inputs (list[numpy.ndarray]): Subsampled inputs
                to be used when subsample_factor is less than one.
            name (str): The name argument for tf.name_scope.

        """
        with tf.name_scope(
                name,
                'optimize',
                values=[inputs, extra_inputs, subsample_grouped_inputs]):
            prev_param = np.copy(self._target.get_param_values())
            inputs = tuple(inputs)
            if extra_inputs is None:
                extra_inputs = tuple()

            subsample_inputs = inputs
            if self._subsample_factor < 1:
                if subsample_grouped_inputs is None:
                    subsample_grouped_inputs = [inputs]
                subsample_inputs = tuple()
                for inputs_grouped in subsample_grouped_inputs:
                    n_samples = len(inputs_grouped[0])
                    inds = np.random.choice(n_samples,
                                            int(n_samples *
                                                self._subsample_factor),
                                            replace=False)
                    subsample_inputs += tuple(
                        [x[inds] for x in inputs_grouped])

            logger.log(
                ('Start CG optimization: '
                 '#parameters: %d, #inputs: %d, #subsample_inputs: %d') %
                (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

            logger.log('computing loss before')
            loss_before = sliced_fun(self._opt_fun['f_loss'],
                                     self._num_slices)(inputs, extra_inputs)

            logger.log('computing gradient')
            flat_g = sliced_fun(self._opt_fun['f_grad'],
                                self._num_slices)(inputs, extra_inputs)
            logger.log('gradient computed')

            logger.log('computing descent direction')
            hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
            descent_direction = cg(hx, flat_g, cg_iters=self._cg_iters)

            initial_step_size = np.sqrt(
                2.0 * self._max_constraint_val *
                (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
            if np.isnan(initial_step_size):
                initial_step_size = 1.
            flat_descent_step = initial_step_size * descent_direction

            logger.log('descent direction computed')

            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):  # yapf: disable
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param)
                loss, constraint_val = sliced_fun(
                    self._opt_fun['f_loss_constraint'],
                    self._num_slices)(inputs, extra_inputs)
                if (loss < loss_before
                        and constraint_val <= self._max_constraint_val):
                    break
            if (np.isnan(loss) or np.isnan(constraint_val)
                    or loss >= loss_before or constraint_val >=
                    self._max_constraint_val) and not self._accept_violation:
                logger.log(
                    'Line search condition violated. Rejecting the step!')
                if np.isnan(loss):
                    logger.log('Violated because loss is NaN')
                if np.isnan(constraint_val):
                    logger.log('Violated because constraint %s is NaN' %
                               self._constraint_name)
                if loss >= loss_before:
                    logger.log('Violated because loss not improving')
                if constraint_val >= self._max_constraint_val:
                    logger.log('Violated because constraint %s is violated' %
                               self._constraint_name)
                self._target.set_param_values(prev_param)
            logger.log('backtrack iters: %d' % n_iter)
            logger.log('optimization finished')
Example #5
 def eval(x):
     # Unflatten x into per-parameter shapes, evaluate the sliced
     # Hessian-vector product, and add the reg_coeff * x damping term.
     xs = tuple(self.target.flat_to_params(x, trainable=True))
     ret = sliced_fun(self.opt_fun["f_hx_plain"], self._num_slices)(
         inputs, xs) + self.reg_coeff * x
     return ret
Example #6
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None,
                 name=None):
        # Variant of the optimize() shown in Example #4, using
        # trainable=True parameter access and an extra NaN debug break.
        with tf.name_scope(
                name,
                "optimize",
                values=[inputs, extra_inputs, subsample_grouped_inputs]):
            prev_param = np.copy(self._target.get_param_values(trainable=True))
            inputs = tuple(inputs)
            if extra_inputs is None:
                extra_inputs = tuple()

            if self._subsample_factor < 1:
                if subsample_grouped_inputs is None:
                    subsample_grouped_inputs = [inputs]
                subsample_inputs = tuple()
                for inputs_grouped in subsample_grouped_inputs:
                    n_samples = len(inputs_grouped[0])
                    inds = np.random.choice(
                        n_samples,
                        int(n_samples * self._subsample_factor),
                        replace=False)
                    subsample_inputs += tuple(
                        [x[inds] for x in inputs_grouped])
            else:
                subsample_inputs = inputs

            logger.log(
                ("Start CG optimization: "
                 "#parameters: %d, #inputs: %d, #subsample_inputs: %d") %
                (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

            logger.log("computing loss before")
            loss_before = sliced_fun(self._opt_fun["f_loss"],
                                     self._num_slices)(inputs, extra_inputs)
            logger.log("performing update")

            logger.log("computing gradient")
            flat_g = sliced_fun(self._opt_fun["f_grad"],
                                self._num_slices)(inputs, extra_inputs)
            logger.log("gradient computed")

            logger.log("computing descent direction")
            hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

            descent_direction = krylov.cg(hx, flat_g, cg_iters=self._cg_iters)

            initial_step_size = np.sqrt(
                2.0 * self._max_constraint_val *
                (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
            if np.isnan(initial_step_size):
                initial_step_size = 1.
            flat_descent_step = initial_step_size * descent_direction

            logger.log("descent direction computed")

            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):  # yapf: disable
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, constraint_val = sliced_fun(
                    self._opt_fun["f_loss_constraint"],
                    self._num_slices)(inputs, extra_inputs)
                if self._debug_nan and np.isnan(constraint_val):
                    break
                if loss < loss_before and \
                   constraint_val <= self._max_constraint_val:
                    break
            if (np.isnan(loss) or np.isnan(constraint_val)
                    or loss >= loss_before or constraint_val >=
                    self._max_constraint_val) and not self._accept_violation:
                logger.log(
                    "Line search condition violated. Rejecting the step!")
                if np.isnan(loss):
                    logger.log("Violated because loss is NaN")
                if np.isnan(constraint_val):
                    logger.log("Violated because constraint %s is NaN" %
                               self._constraint_name)
                if loss >= loss_before:
                    logger.log("Violated because loss not improving")
                if constraint_val >= self._max_constraint_val:
                    logger.log("Violated because constraint %s is violated" %
                               self._constraint_name)
                self._target.set_param_values(prev_param, trainable=True)
            logger.log("backtrack iters: %d" % n_iter)
            logger.log("computing loss after")
            logger.log("optimization finished")
Example #7
 def constraint_val(self, inputs, extra_inputs=None):
     # Evaluate the constraint (e.g. mean KL divergence) over the sliced
     # inputs; extra_inputs are passed through unsliced.
     inputs = tuple(inputs)
     if extra_inputs is None:
         extra_inputs = tuple()
     return sliced_fun(self._opt_fun["f_constraint"],
                       self._num_slices)(inputs, extra_inputs)
Example #8
 def loss(self, inputs, extra_inputs=None):
     # Evaluate the loss over the sliced inputs; extra_inputs are passed
     # through unsliced.
     inputs = tuple(inputs)
     if extra_inputs is None:
         extra_inputs = tuple()
     return sliced_fun(self._opt_fun["f_loss"],
                       self._num_slices)(inputs, extra_inputs)
Example #9
    def get_magnitudes(self,
                       directions,
                       inputs,
                       max_constraint_val=None,
                       extra_inputs=None,
                       subsample_grouped_inputs=None):
        # Line-search each candidate direction and collect the accepted
        # magnitudes and constraint values.
        if max_constraint_val is not None:
            self._max_constraint_val = max_constraint_val
        prev_param = np.copy(self._target.get_param_values(trainable=True))
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(n_samples,
                                        int(n_samples *
                                            self._subsample_factor),
                                        replace=False)
                subsample_inputs += tuple([x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs

        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

        magnitudes = []
        constraint_vals = []
        for descent_direction in directions:
            initial_step_size = np.sqrt(
                2.0 * self._max_constraint_val *
                (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
            if np.isnan(initial_step_size):
                initial_step_size = 1.
            flat_descent_step = initial_step_size * descent_direction

            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                constraint_val = sliced_fun(self._opt_fun["f_constraint"],
                                            self._num_slices)(inputs,
                                                              extra_inputs)
                if self._debug_nan and np.isnan(constraint_val):
                    import ipdb
                    ipdb.set_trace()
                if constraint_val <= self._max_constraint_val:
                    break
            if (np.isnan(constraint_val) or constraint_val >=
                    self._max_constraint_val) and not self._accept_violation:
                logger.log(
                    "Line search condition violated. Rejecting the step!")
                if np.isnan(constraint_val):
                    logger.log("Violated because constraint %s is NaN" %
                               self._constraint_name)
                if constraint_val >= self._max_constraint_val:
                    logger.log("Violated because constraint %s is violated" %
                               self._constraint_name)
                self._target.set_param_values(prev_param, trainable=True)
            # logger.log("backtrack iters: %d" % n_iter)
            # logger.log("final magnitude: " + str(-ratio*initial_step_size))
            logger.log("final kl: " + str(constraint_val))
            # logger.log("optimization finished")
            magnitudes.append(-ratio * initial_step_size)
            constraint_vals.append(constraint_val)
        return magnitudes, constraint_vals
Example #10
    def get_magnitude(self,
                      direction,
                      inputs,
                      max_constraint_val=None,
                      extra_inputs=None,
                      subsample_grouped_inputs=None):
        """Calculate the update magnitude.

        Parameters
        ----------
        direction : numpy.ndarray
            The gradient (descent) direction, as a flat parameter vector.
        inputs :
            A list of symbolic variables as inputs, which could be subsampled if needed. It is
            assumed that the first dimension of these inputs corresponds to the number of data points.
        max_constraint_val : float, optional
            The maximum value for the constrained variable.
        extra_inputs : optional
            A list of symbolic variables as extra inputs which should not be subsampled.
        subsample_grouped_inputs : optional
            The list of inputs that are needed to be subsampled.

        Returns
        -------
        magnitude : float
            The update magnitude.
        constraint_val : float
            The constraint value at the accepted step.
        """
        if max_constraint_val is not None:
            self._max_constraint_val = max_constraint_val
        prev_param = np.copy(self._target.get_param_values(trainable=True))
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(n_samples,
                                        int(n_samples *
                                            self._subsample_factor),
                                        replace=False)
                subsample_inputs += tuple([x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs

        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

        descent_direction = direction

        initial_step_size = np.sqrt(
            2.0 * self._max_constraint_val *
            (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
        if np.isnan(initial_step_size):
            initial_step_size = 1.
        flat_descent_step = initial_step_size * descent_direction

        n_iter = 0
        for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            constraint_val = sliced_fun(self._opt_fun["f_constraint"],
                                        self._num_slices)(inputs, extra_inputs)
            if self._debug_nan and np.isnan(constraint_val):
                import ipdb
                ipdb.set_trace()
            if constraint_val <= self._max_constraint_val:
                break
        if (np.isnan(constraint_val) or constraint_val >=
                self._max_constraint_val) and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint %s is NaN" %
                           self._constraint_name)
            if constraint_val >= self._max_constraint_val:
                logger.log("Violated because constraint %s is violated" %
                           self._constraint_name)
            self._target.set_param_values(prev_param, trainable=True)
        # logger.log("backtrack iters: %d" % n_iter)
        # logger.log("final magnitude: " + str(-ratio*initial_step_size))
        logger.log("final kl: " + str(constraint_val))
        # logger.log("optimization finished")
        return -ratio * initial_step_size, constraint_val