Example #1
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None,
                 name=None):
        with tf.name_scope(
                name,
                'optimize',
                values=[inputs, extra_inputs, subsample_grouped_inputs]):
            prev_param = np.copy(self._target.get_param_values(trainable=True))
            inputs = tuple(inputs)
            if extra_inputs is None:
                extra_inputs = tuple()

            # Optionally subsample the inputs used for the Hessian-vector products
            if self._subsample_factor < 1:
                if subsample_grouped_inputs is None:
                    subsample_grouped_inputs = [inputs]
                subsample_inputs = tuple()
                for inputs_grouped in subsample_grouped_inputs:
                    n_samples = len(inputs_grouped[0])
                    inds = np.random.choice(n_samples,
                                            int(n_samples *
                                                self._subsample_factor),
                                            replace=False)
                    subsample_inputs += tuple(
                        [x[inds] for x in inputs_grouped])
            else:
                subsample_inputs = inputs

            logger.log(
                ('Start CG optimization: '
                 '#parameters: %d, #inputs: %d, #subsample_inputs: %d') %
                (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

            logger.log('computing loss before')
            loss_before = sliced_fun(self._opt_fun['f_loss'],
                                     self._num_slices)(inputs, extra_inputs)
            logger.log('performing update')

            logger.log('computing gradient')
            flat_g = sliced_fun(self._opt_fun['f_grad'],
                                self._num_slices)(inputs, extra_inputs)
            logger.log('gradient computed')

            logger.log('computing descent direction')
            hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
            descent_direction = krylov.cg(hx, flat_g, cg_iters=self._cg_iters)

            # Scale the direction so that 0.5 * step^T H step ~= _max_constraint_val
            initial_step_size = np.sqrt(
                2.0 * self._max_constraint_val *
                (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
            if np.isnan(initial_step_size):
                initial_step_size = 1.
            flat_descent_step = initial_step_size * descent_direction

            logger.log('descent direction computed')

            # Backtracking line search: geometrically shrink the step until the
            # loss improves and the constraint stays within _max_constraint_val
            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):  # yapf: disable
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, constraint_val = sliced_fun(
                    self._opt_fun['f_loss_constraint'],
                    self._num_slices)(inputs, extra_inputs)
                if self._debug_nan and np.isnan(constraint_val):
                    break
                if loss < loss_before and \
                   constraint_val <= self._max_constraint_val:
                    break
            if (np.isnan(loss) or np.isnan(constraint_val)
                    or loss >= loss_before or constraint_val >=
                    self._max_constraint_val) and not self._accept_violation:
                logger.log(
                    'Line search condition violated. Rejecting the step!')
                if np.isnan(loss):
                    logger.log('Violated because loss is NaN')
                if np.isnan(constraint_val):
                    logger.log('Violated because constraint %s is NaN' %
                               self._constraint_name)
                if loss >= loss_before:
                    logger.log('Violated because loss not improving')
                if constraint_val >= self._max_constraint_val:
                    logger.log('Violated because constraint %s is violated' %
                               self._constraint_name)
                self._target.set_param_values(prev_param, trainable=True)
            logger.log('backtrack iters: %d' % n_iter)
            logger.log('computing loss after')
            logger.log('optimization finished')
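
Example #1 delegates the linear solve to krylov.cg(hx, flat_g, ...), which finds the descent direction d satisfying H d = g using only Hessian-vector products. As a point of reference, a minimal NumPy sketch of the standard conjugate gradient routine such a helper typically implements is shown below; the names cg, f_Ax, b, cg_iters and residual_tol are illustrative, and the actual krylov.cg in the source repository may differ in details such as verbosity or residual handling.

import numpy as np

def cg(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    # Solve A x = b given only the matrix-vector product f_Ax(v) = A @ v.
    x = np.zeros_like(b)
    r = b.copy()          # residual b - A x (x starts at zero)
    p = r.copy()          # conjugate search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = f_Ax(p)
        alpha = rdotr / (p.dot(z) + 1e-8)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
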
Example #2
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(n_samples,
                                        int(n_samples *
                                            self._subsample_factor),
                                        replace=False)
                subsample_inputs += tuple([x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs

        logger.log("computing loss before")
        loss_before = sliced_fun(self._opt_fun["f_loss"],
                                 self._num_slices)(inputs, extra_inputs)
        logger.log("performing update")
        logger.log("computing descent direction")

        flat_g = sliced_fun(self._opt_fun["f_grad"],
                            self._num_slices)(inputs, extra_inputs)

        hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

        descent_direction = krylov.cg(hx, flat_g, cg_iters=self._cg_iters)

        initial_step_size = np.sqrt(
            2.0 * self._max_constraint_val *
            (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
        if np.isnan(initial_step_size):
            initial_step_size = 1.
        flat_descent_step = initial_step_size * descent_direction

        logger.log("descent direction computed")

        prev_param = np.copy(self._target.get_param_values(trainable=True))
        n_iter = 0
        for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                self._max_backtracks)):  # yapf: disable
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs, extra_inputs)
            if loss < loss_before \
               and constraint_val <= self._max_constraint_val:
                break
        if (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
                or constraint_val >= self._max_constraint_val) \
           and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint %s is NaN" %
                           self._constraint_name)
            if loss >= loss_before:
                logger.log("Violated because loss not improving")
            if constraint_val >= self._max_constraint_val:
                logger.log("Violated because constraint %s is violated" %
                           self._constraint_name)
            self._target.set_param_values(prev_param, trainable=True)
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
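
In Examples #1 and #2 the initial step size comes from the quadratic trust-region model: the scale beta is chosen so that 0.5 * beta^2 * d^T H d equals _max_constraint_val, and the backtracking loop then shrinks the step geometrically until the loss improves and the constraint holds. A self-contained sketch of that logic, with hypothetical loss_fn, constraint_fn and hvp callables standing in for the sliced TensorFlow functions, might look like this:

import numpy as np

def line_search_step(params, descent_direction, hvp, loss_fn, constraint_fn,
                     max_constraint_val, backtrack_ratio=0.8, max_backtracks=15):
    # Pick beta so that 0.5 * beta^2 * d^T H d == max_constraint_val.
    dHd = descent_direction.dot(hvp(descent_direction))
    beta = np.sqrt(2.0 * max_constraint_val / (dHd + 1e-8))
    loss_before = loss_fn(params)
    # Try steps beta, beta*r, beta*r^2, ... until both conditions are met.
    for ratio in backtrack_ratio ** np.arange(max_backtracks):
        candidate = params - ratio * beta * descent_direction
        if (loss_fn(candidate) < loss_before
                and constraint_fn(candidate) <= max_constraint_val):
            return candidate
    return params  # reject the step if no ratio satisfied both conditions
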
Example #3
    def optimize(self,
                 inputs,
                 gradient,
                 extra_inputs=None,
                 subsample_grouped_inputs=None,
                 name=None):

        prev_param = np.copy(self._target.get_param_values(trainable=True))

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        sample_inputs = inputs

        logger.log(("Start CG optimization: "
                    "#parameters: %d, #inputs: %d, #subsample_inputs: %d") %
                   (len(prev_param), len(inputs[0]), len(sample_inputs[0])))

        logger.log("performing update")
        logger.log("computing descent direction")
        hx = self._hvp_approach.build_eval(sample_inputs)

        descent_direction = krylov.cg(hx, gradient, cg_iters=self._cg_iters)

        initial_step_size = np.sqrt(
            2.0 * self._max_constraint_val *
            (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
        if np.isnan(initial_step_size):
            initial_step_size = 1.
        flat_descent_step = initial_step_size * descent_direction

        logger.log("descent direction computed")

        n_iter = 0
        loss_before = self._opt_fun["f_loss"](*(inputs + extra_inputs))
        for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                self._max_backtracks)):  # yapf: disable
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs + extra_inputs)
            if self._debug_nan and np.isnan(constraint_val):
                break
            if loss < loss_before and \
               constraint_val <= self._max_constraint_val:
                break
        if (np.isnan(loss) or np.isnan(constraint_val) or loss > loss_before
                or constraint_val >= self._max_constraint_val
            ) and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint %s is NaN" %
                           self._constraint_name)
            if loss > loss_before:
                logger.log("Violated because loss not improving")
            if constraint_val >= self._max_constraint_val:
                logger.log("Violated because constraint %s is violated" %
                           self._constraint_name)
            self._target.set_param_values(prev_param, trainable=True)
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
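
Example #3 takes the gradient as an argument and only uses self._hvp_approach.build_eval(...) to obtain the Hessian-vector product callable hx. One common way to get such a callable without materializing the Hessian is a finite-difference approximation of the gradient; the sketch below uses hypothetical grad_fn, theta, eps and reg_coeff names and is only an illustration of what an HVP helper might do, not the repository's actual implementation (which may instead use double backprop through the KL term).

import numpy as np

def build_finite_difference_hvp(grad_fn, theta, eps=1e-5, reg_coeff=1e-5):
    # Return hx(v) ~= H @ v + reg_coeff * v via a central difference of the gradient.
    def hx(v):
        g_plus = grad_fn(theta + eps * v)
        g_minus = grad_fn(theta - eps * v)
        return (g_plus - g_minus) / (2.0 * eps) + reg_coeff * v
    return hx

# The returned callable can be handed directly to a CG solver such as the
# sketch after Example #1:  descent_direction = cg(hx, gradient, cg_iters=10)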