Beispiel #1
0
    def optimize(self, inputs, extra_inputs=None, callback=None):
        """Minimize the loss with mini-batch updates until convergence.

        Every epoch makes one pass over ``inputs`` in batches of
        ``self._batch_size``, running ``self._train_op`` on each batch. It then
        re-evaluates the full loss through ``sliced_fun``. The loop ends after
        ``self._max_epochs`` epochs, or earlier once the change in loss between
        consecutive evaluations falls below ``self._tolerance``.

        Args:
            inputs: sequence of sliceable input arrays; must be non-empty.
            extra_inputs: optional tuple of non-sliced inputs (``()`` if None).
            callback: optional callable invoked after each epoch with keyword
                arguments ``loss``, ``params``, ``itr`` and ``elapsed``.

        Raises:
            NotImplementedError: when ``inputs`` is empty; only mini-batch
                sampling is supported.
        """
        if not len(inputs):
            # Mini-batch sampling is the only supported mode.
            raise NotImplementedError

        loss_fn = self._opt_fun["f_loss"]
        if extra_inputs is None:
            extra_inputs = tuple()

        prev_loss = sliced_fun(loss_fn, self._num_slices)(inputs, extra_inputs)
        t_start = time.time()
        dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)
        session = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            for batch in dataset.iterate(update=True):
                # Optionally skip the trailing, smaller-than-batch_size batch.
                if self._ignore_last and len(batch[0]) != self._batch_size:
                    continue
                feed = dict(zip(self._input_vars, batch))
                session.run(self._train_op, feed)
                if self._verbose:
                    progbar.update(len(batch[0]))

            if self._verbose and progbar.active:
                progbar.stop()

            cur_loss = sliced_fun(loss_fn, self._num_slices)(inputs, extra_inputs)

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, cur_loss))

            if self._callback or callback:
                elapsed = time.time() - t_start
                if self._target:
                    params = self._target.get_param_values(trainable=True)
                else:
                    params = None
                info = dict(loss=cur_loss, params=params, itr=epoch, elapsed=elapsed)
                if self._callback:
                    self._callback(info)
                if callback:
                    callback(**info)

            if abs(prev_loss - cur_loss) < self._tolerance:
                break
            prev_loss = cur_loss
 def eval(x):
     """Evaluate ``f_Hx_plain`` at ``x`` and add ``self.reg_coeff * x``.

     ``x`` is a flat parameter vector. It is unflattened with
     ``self.target.flat_to_params`` before being passed to the sliced
     ``f_Hx_plain``. ``self`` and ``inputs`` come from the enclosing scope.
     """
     if config.TF_NN_SETTRACE:
         ipdb.set_trace()  # optional debugger hook
     param_parts = tuple(self.target.flat_to_params(x, trainable=True))
     hx_plain = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)
     return hx_plain(inputs, param_parts) + self.reg_coeff * x
Beispiel #3
0
    def line_search(self, descent_step, inputs, extra_inputs=()):
        """Backtracking line search along ``-descent_step``.

        Tries ``prev_w - ratio * descent_step`` for ``ratio`` in
        ``self._backtrack_ratio ** arange(self._max_backtracks)``. It accepts
        the first candidate whose loss is strictly below the starting loss and
        whose constraint value is at most ``self._max_constraint_val``. If no
        candidate is accepted, the original parameters are restored.

        Args:
            descent_step: flat step vector, subtracted from the current params.
            inputs: tuple of inputs for the loss / constraint functions.
            extra_inputs: tuple of non-sliced extra inputs.

        Returns:
            True if a step was accepted (the target keeps the new parameters),
            False otherwise (the target is reset to the starting parameters).
        """
        f_loss = self._opt_fun["f_loss"]
        f_loss_constraint = self._opt_fun["f_loss_constraint"]
        prev_w = np.copy(self._target.get_param_values(trainable=True))
        loss_before = f_loss(*(inputs + extra_inputs))
        n_iter = 0
        succ_line_search = False
        for n_iter, ratio in enumerate(
                self._backtrack_ratio ** np.arange(self._max_backtracks)):
            cur_step = ratio * descent_step
            cur_w = prev_w - cur_step
            self._target.set_param_values(cur_w, trainable=True)
            loss, constraint_val = sliced_fun(f_loss_constraint,
                                              self._num_slices)(inputs, extra_inputs)
            if loss < loss_before and constraint_val <= self._max_constraint_val:
                succ_line_search = True
                break

        # Reject exactly when the acceptance test above never passed. The old
        # post-check used `constraint_val >= max`. That rolled back a step the
        # loop had accepted with the constraint exactly on the bound, while
        # still returning True. NaN values already fail the `<`/`<=` tests.
        if not succ_line_search:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint is NaN")
            if loss >= loss_before:
                logger.log("Violated because loss not improving")
            if constraint_val > self._max_constraint_val:
                logger.log(
                    "Violated because constraint {:} is violated".format(constraint_val))

            self._target.set_param_values(prev_w, trainable=True)

        logger.log("backtrack iters: %d" % n_iter)
        return succ_line_search
Beispiel #4
0
        def line_search(check_loss=True, check_quad=True, check_lin=True):
            """Backtrack along ``flat_descent_step`` until every enabled check passes.

            Candidates are ``prev_param - ratio * flat_descent_step`` for
            ``ratio`` in ``self._backtrack_ratio ** arange(self._max_backtracks)``.
            A candidate is accepted as soon as each enabled check holds:

            * check_loss: ``loss < loss_before``
            * check_quad: ``quad_constraint_val <= self._max_quad_constraint_val``
            * check_lin:  ``lin_constraint_val <= lin_reject_threshold``

            ``flat_descent_step``, ``prev_param``, ``loss_before``,
            ``lin_reject_threshold``, ``inputs`` and ``extra_inputs`` come from
            the enclosing scope. The target is left at the last candidate tried,
            whether or not it was accepted.

            Returns:
                (loss, quad_constraint_val, lin_constraint_val, n_iter) for the
                last candidate evaluated.
            """
            # The former loss/quad/lin reject counters were never read or
            # returned, so they are dropped.
            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)):
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                    self._opt_fun["f_loss_constraint"], self._num_slices)(inputs, extra_inputs)
                # A disabled check counts as passed.
                loss_ok = (not check_loss) or loss < loss_before
                quad_ok = (not check_quad) or quad_constraint_val <= self._max_quad_constraint_val
                lin_ok = (not check_lin) or lin_constraint_val <= lin_reject_threshold
                if loss_ok and quad_ok and lin_ok:
                    break

            return loss, quad_constraint_val, lin_constraint_val, n_iter
Beispiel #5
0
    def mean_kl(self, samples_data):
        """Evaluate each function in ``self.f_constraints`` on ``samples_data``.

        The inputs come from ``self.construct_inputs(samples_data)``. Each
        function is evaluated through ``sliced_fun`` with a single slice.

        Returns:
            list with one result per constraint function, in order.
        """
        all_input_values = self.construct_inputs(samples_data)
        return [sliced_fun(f_constraint, 1)(all_input_values)
                for f_constraint in self.f_constraints]
 def constraint_val(self, inputs, extra_inputs=None):
     """Evaluate ``f_constraint`` on ``inputs`` split into ``self._num_slices``.

     ``extra_inputs`` defaults to an empty tuple when None.
     """
     if config.TF_NN_SETTRACE:
         ipdb.set_trace()  # optional debugger hook
     sliced_inputs = tuple(inputs)
     non_sliced = tuple() if extra_inputs is None else extra_inputs
     constraint_fn = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
     return constraint_fn(sliced_inputs, non_sliced)
 def _constraint_val(self, inputs, extra_inputs):
     """
     Parallelized: returns the same value in all workers.

     Each worker stores its ``avg_fac``-weighted local constraint value in
     slot ``self.rank`` of ``shareds.constraint_val``. It then waits on the
     ``cnstr`` barrier and returns the sum over all slots.
     """
     shareds, barriers = self._par_objs
     local_fn = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
     weighted = self.avg_fac * local_fn(inputs, extra_inputs)
     shareds.constraint_val[self.rank] = weighted
     barriers.cnstr.wait()
     return sum(shareds.constraint_val)
def get_gradient(algo, samples_data, flat=False):
    """Evaluate the optimizer's compiled policy gradient on a batch.

    The input tuple is, in order: observations, actions, advantages; the
    policy's state-info arrays; its old distribution-info arrays; and
    ``samples_data["valids"]`` when the policy is recurrent. It is fed in a
    single slice to ``f_grad`` (``flat=True``) or ``f_grads`` (``flat=False``)
    of ``algo.optimizer``.

    Args:
        algo: object exposing ``policy`` and ``optimizer``.
        samples_data: dict with "observations", "actions", "advantages",
            "agent_infos", plus "valids" for recurrent policies.
        flat: select ``f_grad`` instead of ``f_grads``.

    Returns:
        The output of the selected compiled gradient function.
    """
    all_input_values = tuple(
        ext.extract(samples_data, "observations", "actions", "advantages"))
    agent_infos = samples_data["agent_infos"]
    all_input_values += tuple(agent_infos[k] for k in algo.policy.state_info_keys)
    all_input_values += tuple(
        agent_infos[k] for k in algo.policy.distribution.dist_info_keys)
    # Recurrent policies also take the padding mask, as the other gradient
    # helpers in this code do. getattr keeps policies that lack a
    # `recurrent` attribute working unchanged.
    if getattr(algo.policy, "recurrent", False):
        all_input_values += (samples_data["valids"],)

    grad_key = "f_grad" if flat else "f_grads"
    return sliced_fun(algo.optimizer._opt_fun[grad_key], 1)(all_input_values, tuple())
Beispiel #9
0
 def flat_g(self, inputs, extra_inputs=None):
     """Return the flat gradient summed across all parallel workers.

     Each worker writes its ``avg_fac``-weighted gradient into column
     ``self.rank`` of ``shareds.grads_2d``. After the first barrier, rank 0
     sums the columns into ``shareds.flat_g``. After the second barrier,
     every worker returns that shared result.
     """
     shareds, barriers = self._par_objs
     grad_fn = ext.sliced_fun(self._opt_fun["f_grad"], self._num_slices)
     shareds.grads_2d[:, self.rank] = self.avg_fac * grad_fn(inputs, extra_inputs)
     barriers.flat_g[0].wait()
     is_root = self.rank == 0
     if is_root:
         # A single worker performs the reduction.
         shareds.flat_g = np.sum(shareds.grads_2d, axis=1)
     barriers.flat_g[1].wait()
     return shareds.flat_g
 def _loss_constraint(self, inputs, extra_inputs):
     """
     Parallelized: returns the same values in all workers.

     Each worker stores its ``avg_fac``-weighted local loss and constraint
     value in slot ``self.rank`` of the shared arrays. It then synchronizes
     on the ``loss_cnstr`` barrier and returns the sums over all slots.
     """
     shareds, barriers = self._par_objs
     local_fn = sliced_fun(self._opt_fun["f_loss_constraint"], self._num_slices)
     local_loss, local_cnstr = local_fn(inputs, extra_inputs)
     rank, weight = self.rank, self.avg_fac
     shareds.loss[rank] = weight * local_loss
     shareds.constraint_val[rank] = weight * local_cnstr
     barriers.loss_cnstr.wait()
     return sum(shareds.loss), sum(shareds.constraint_val)
Beispiel #11
0
 def check_nan():
     """Restore ``prev_param`` if the loss or either constraint is NaN.

     Evaluates ``f_loss_constraint`` at the target's current parameters. If
     any of the three values is NaN, it logs which ones and resets the
     target to ``prev_param``, which comes from the enclosing scope.
     Returns None.
     """
     results = sliced_fun(
         self._opt_fun["f_loss_constraint"], self._num_slices)(inputs, extra_inputs)
     loss, quad_constraint_val, lin_constraint_val = results
     loss_is_nan = np.isnan(loss)
     quad_is_nan = np.isnan(quad_constraint_val)
     lin_is_nan = np.isnan(lin_constraint_val)
     if not (loss_is_nan or quad_is_nan or lin_is_nan):
         return
     if loss_is_nan:
         logger.log("Violated because loss is NaN")
     if quad_is_nan:
         logger.log("Violated because quad_constraint %s is NaN" %
                    self._constraint_name_1)
     if lin_is_nan:
         logger.log("Violated because lin_constraint %s is NaN" %
                    self._constraint_name_2)
     self._target.set_param_values(prev_param, trainable=True)
Beispiel #12
0
 def get_grad(self, samples_data):
     """Evaluate ``f_grads`` on the batch in ``samples_data`` (single slice).

     Inputs, in order: observations, actions, advantages; the policy's
     state-info arrays; its old distribution-info arrays; and ``valids``
     when the policy is recurrent.
     """
     obs_act_adv = tuple(
         ext.extract(samples_data, "observations", "actions", "advantages"))
     agent_infos = samples_data["agent_infos"]
     policy = self.policy
     state_infos = tuple(agent_infos[k] for k in policy.state_info_keys)
     dist_infos = tuple(agent_infos[k] for k in policy.distribution.dist_info_keys)
     valids = (samples_data["valids"],) if policy.recurrent else ()
     grads_fn = sliced_fun(self.optimizer._opt_fun["f_grads"], 1)
     return grads_fn(obs_act_adv + state_infos + dist_infos + valids)
Beispiel #13
0
        def line_search(check_loss=True, check_quad=True, check_lin=True):
            """Backtracking line search that logs why each candidate is rejected.

            Tries ``prev_param - ratio * flat_descent_step`` for ``ratio`` in
            ``self._backtrack_ratio ** arange(self._max_backtracks)``. Stops at
            the first candidate for which every enabled check passes:

            * check_loss: ``loss < loss_before``
            * check_quad: ``quad_constraint_val <= self._max_quad_constraint_val``
            * check_lin:  ``lin_constraint_val <= lin_reject_threshold``

            ``flat_descent_step``, ``prev_param``, ``loss_before``,
            ``lin_reject_threshold``, ``inputs`` and ``extra_inputs`` come from
            the enclosing scope. The target is left at the last candidate
            tried, whether or not it was accepted. The iteration count and the
            per-check rejection counts are recorded as tabular diagnostics.

            Returns:
                (loss, quad_constraint_val, lin_constraint_val, n_iter) for the
                last candidate evaluated.
            """
            # Per-check rejection counters, reported via record_tabular below.
            loss_rejects = 0
            quad_rejects = 0
            lin_rejects = 0
            # NOTE(review): if _max_backtracks is 0 the loop never runs and
            # loss / *_constraint_val are unbound at the return statement.
            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                    self._opt_fun["f_loss_constraint"],
                    self._num_slices)(inputs, extra_inputs)
                loss_flag = loss < loss_before
                quad_flag = quad_constraint_val <= self._max_quad_constraint_val
                lin_flag = lin_constraint_val <= lin_reject_threshold
                if check_loss and not (loss_flag):
                    logger.log("At backtrack itr %i, loss failed to improve." %
                               n_iter)
                    loss_rejects += 1
                if check_quad and not (quad_flag):
                    logger.log(
                        "At backtrack itr %i, quad constraint violated." %
                        n_iter)
                    # Percentage by which the value exceeds its bound.
                    logger.log(
                        "Quad constraint violation was %.3f %%." %
                        (100 *
                         (quad_constraint_val / self._max_quad_constraint_val)
                         - 100))
                    quad_rejects += 1
                if check_lin and not (lin_flag):
                    logger.log(
                        "At backtrack itr %i, expression for lin constraint failed to improve."
                        % n_iter)
                    # NOTE(review): assumes lin_reject_threshold != 0 (and
                    # presumably > 0 for the percentage to be meaningful).
                    logger.log(
                        "Lin constraint violation was %.3f %%." %
                        (100 *
                         (lin_constraint_val / lin_reject_threshold) - 100))
                    lin_rejects += 1

                # Accept when every enabled check passes (disabled checks count as passed).
                if (loss_flag or not (check_loss)) and (
                        quad_flag
                        or not (check_quad)) and (lin_flag or not (check_lin)):
                    logger.log("Accepted step at backtrack itr %i." % n_iter)
                    break

            logger.record_tabular("BacktrackIters", n_iter)
            logger.record_tabular("LossRejects", loss_rejects)
            logger.record_tabular("QuadRejects", quad_rejects)
            logger.record_tabular("LinRejects", lin_rejects)
            return loss, quad_constraint_val, lin_constraint_val, n_iter
 def _flat_g(self, inputs, extra_inputs):
     """
     Parallelized: returns the same values in all workers.

     Each worker writes its ``avg_fac``-weighted local gradient into column
     ``self.rank`` of ``shareds.grads_2d``. After the first barrier, each
     worker sums its own row range ``self.vb`` across workers into
     ``shareds.flat_g``. After the second barrier, the fully reduced
     gradient is in place for every worker.

     Returns:
         ``shareds.flat_g``, the shared reduced flat gradient.
     """
     shareds, barriers = self._par_objs
     # Each worker records result available to all.
     shareds.grads_2d[:, self.rank] = self.avg_fac * \
         sliced_fun(self._opt_fun["f_grad"], self._num_slices)(inputs, extra_inputs)
     barriers.flat_g[0].wait()
     # Each worker sums over an equal share of the grad elements across
     # workers (row major storage--sum along rows).
     shareds.flat_g[self.vb[0]:self.vb[1]] = \
         np.sum(shareds.grads_2d[self.vb[0]:self.vb[1], :], axis=1)
     barriers.flat_g[1].wait()
     # The function previously fell off the end and returned None, contrary
     # to its docstring. Return the shared buffer; callers that read
     # shareds.flat_g directly are unaffected.
     return shareds.flat_g
Beispiel #15
0
    def get_gradient(self, samples_data):
        """Evaluate the multitask ``f_grads`` on ``samples_data``.

        Input order: observations, actions, advantages; state-info arrays;
        old dist-info arrays; ``valids`` (recurrent policies only); one
        observation array per task; each task's old dist-info arrays (one per
        dist_info_key); and finally ``self.kl_weights``.

        NOTE: samples are bucketed into tasks by ``np.random.randint`` (one
        draw per observation, in order). The original author marks this as a
        fake task id that only satisfies the graph's input signature. It
        consumes the global NumPy RNG.
        """
        base_inputs = tuple(
            ext.extract(samples_data, "observations", "actions", "advantages"))

        agent_infos = samples_data["agent_infos"]
        policy = self.policy
        dist_keys = policy.distribution.dist_info_keys
        base_inputs += tuple(agent_infos[k] for k in policy.state_info_keys)
        base_inputs += tuple(agent_infos[k] for k in dist_keys)
        if policy.recurrent:
            base_inputs += (samples_data["valids"],)

        # multitask related: one bucket of observations / dist infos per task
        observations = samples_data["observations"]
        obs_by_task = [[] for _ in range(self.task_num)]
        dist_by_task = [[[] for _ in dist_keys] for _ in range(self.task_num)]
        for idx in range(len(observations)):
            # fake the taskid to satisfy the calculation requirement, very ugly
            taskid = np.random.randint(self.task_num)
            obs_by_task[taskid].append(observations[idx])
            for j, key in enumerate(dist_keys):
                dist_by_task[taskid][j].append(agent_infos[key][idx])

        task_inputs = tuple(np.array(bucket) for bucket in obs_by_task)
        for per_key in dist_by_task:
            task_inputs += tuple(np.array(vals) for vals in per_key)

        all_input_values = base_inputs + task_inputs + (self.kl_weights,)
        return sliced_fun(self.optimizer._opt_fun["f_grads"], 1)(all_input_values)
Beispiel #16
0
    def mean_kl(self, samples_data):
        """Evaluate each function in ``self.f_constraints`` on ``samples_data``.

        Input order: observations, actions, advantages; state-info arrays;
        old dist-info arrays; ``valids`` (recurrent policies only); one
        observation array per task; each task's old dist-info arrays; and
        ``self.kl_weights``. Samples are bucketed by task using
        ``samples_data["env_infos"]["state_index"]``.

        Returns:
            list with one result per constraint function, in order.
        """
        shared_inputs = tuple(
            ext.extract(samples_data, "observations", "actions", "advantages"))

        agent_infos = samples_data["agent_infos"]
        policy = self.policy
        dist_keys = policy.distribution.dist_info_keys
        shared_inputs += tuple(agent_infos[k] for k in policy.state_info_keys)
        shared_inputs += tuple(agent_infos[k] for k in dist_keys)
        if policy.recurrent:
            shared_inputs += (samples_data["valids"],)

        # multitask related: group samples by their recorded task index
        observations = samples_data["observations"]
        obs_by_task = [[] for _ in range(self.task_num)]
        dist_by_task = [[[] for _ in dist_keys] for _ in range(self.task_num)]
        for idx in range(len(observations)):
            taskid = samples_data["env_infos"]["state_index"][idx]
            obs_by_task[taskid].append(observations[idx])
            for j, key in enumerate(dist_keys):
                dist_by_task[taskid][j].append(agent_infos[key][idx])

        per_task_inputs = tuple(np.array(bucket) for bucket in obs_by_task)
        for per_key in dist_by_task:
            per_task_inputs += tuple(np.array(vals) for vals in per_key)

        all_input_values = shared_inputs + per_task_inputs + (self.kl_weights,)
        return [sliced_fun(f_constraint, 1)(all_input_values)
                for f_constraint in self.f_constraints]
        def parallel_eval(x):
            """
            Parallelized.

            Computes this worker's ``avg_fac``-weighted ``f_Hx_plain`` result
            for the flat vector ``x`` into column ``self.pd.rank`` of
            ``shareds.grads_2d``. After the first barrier, each worker sums its
            row range ``self.pd.vb`` across workers and adds
            ``self.reg_coeff * x`` on that range. After the second barrier, it
            returns the shared ``Hx`` buffer.
            """
            shareds, barriers = self._par_objs

            param_parts = tuple(self.target.flat_to_params(x, trainable=True))
            hx_fn = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)
            shareds.grads_2d[:, self.pd.rank] = self.pd.avg_fac * hx_fn(inputs, param_parts)
            barriers.Hx[0].wait()

            lo, hi = self.pd.vb[0], self.pd.vb[1]
            damping = self.reg_coeff * x[lo:hi]
            shareds.Hx[lo:hi] = damping + np.sum(shareds.grads_2d[lo:hi, :], axis=1)
            barriers.Hx[1].wait()
            return shareds.Hx  # callers may also read this shared buffer directly
Beispiel #18
0
    def get_gradient(self, samples_data):
        """Evaluate the gradient of one task's objective on its own samples.

        The task is the LAST entry of ``samples_data["env_infos"]["state_index"]``.
        Samples are bucketed by their ``state_index``, and that task's bucket
        is fed to ``self.f_task_grads[task_id]`` in this order:
        observations (as an array), actions, advantages, the task's old
        dist-info arrays (one per dist_info_key), then the full, unfiltered
        state-info arrays.

        The unused ``ext.extract`` tuple and ``dist_info_list`` of the
        previous version are gone. The keys they read are still read below,
        so missing-key errors are unchanged.
        """
        agent_infos = samples_data["agent_infos"]
        state_index = samples_data["env_infos"]["state_index"]
        task_id = state_index[-1]

        dist_info_keys = self.policy.distribution.dist_info_keys
        # NOTE(review): state infos are passed unfiltered (all samples) while
        # the other inputs are per-task; presumably state_info_keys is empty
        # for this policy -- confirm.
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]

        task_obs = [[] for _ in range(self.task_num)]
        task_actions = [[] for _ in range(self.task_num)]
        task_advantages = [[] for _ in range(self.task_num)]
        task_old_dist_info_list = [[[] for _ in dist_info_keys]
                                   for _ in range(self.task_num)]
        for i in range(len(samples_data["observations"])):
            taskid = state_index[i]
            task_obs[taskid].append(samples_data["observations"][i])
            task_actions[taskid].append(samples_data["actions"][i])
            task_advantages[taskid].append(samples_data["advantages"][i])
            for j, k in enumerate(dist_info_keys):
                task_old_dist_info_list[taskid][j].append(agent_infos[k][i])

        # Only the selected task's data is needed; as before, observations and
        # dist infos become arrays while actions/advantages stay lists.
        old_dist_info = tuple(np.array(vals) for vals in task_old_dist_info_list[task_id])
        input_values = (
            (np.array(task_obs[task_id]), task_actions[task_id], task_advantages[task_id])
            + old_dist_info
            + tuple(state_info_list)
        )

        return sliced_fun(self.f_task_grads[task_id], 1)(input_values)
Beispiel #19
0
        def wrap_up():
            """Record diagnostics for the linear (safety) constraint after a step.

            When ``optim_case < 4``, re-evaluates ``f_lin_constraint`` at the
            target's current parameters. It then logs:

            * the change in the lin constraint since ``prev_lin_constraint_val``;
            * the linear prediction ``S + flat_b . (cur_param - prev_param)``;
            * the surrogate prediction ``S + lin_constraint_delta``;
            * the relative error between the two;
            * how far the predictions stored on the previous call are from ``S``.

            For any other ``optim_case``, it logs zeros and resets the stored
            predictions to 0.

            Reads ``optim_case``, ``inputs``, ``extra_inputs``,
            ``prev_lin_constraint_val``, ``prev_param``, ``S`` and ``flat_b``
            from the enclosing scope. Updates ``self._last_lin_pred_S`` and
            ``self._last_surr_pred_S`` for the next call.
            """
            if optim_case < 4:
                lin_constraint_val = sliced_fun(
                    self._opt_fun["f_lin_constraint"],
                    self._num_slices)(inputs, extra_inputs)
                lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
                logger.record_tabular("LinConstraintDelta",
                                      lin_constraint_delta)

                # NOTE(review): called without trainable=True, unlike the
                # set/get_param_values calls elsewhere; confirm that
                # prev_param has the same layout.
                cur_param = self._target.get_param_values()

                # First-order prediction of S vs. the surrogate-based one.
                next_linear_S = S + flat_b.dot(cur_param - prev_param)
                next_surrogate_S = S + lin_constraint_delta

                # Percent error of the linear prediction relative to the
                # surrogate; NOTE(review): no guard against next_surrogate_S == 0.
                lin_surrogate_acc = 100. * (
                    next_linear_S - next_surrogate_S) / next_surrogate_S

                logger.record_tabular("PredictedLinearS", next_linear_S)
                logger.record_tabular("PredictedSurrogateS", next_surrogate_S)
                logger.record_tabular("LinearSurrogateErr", lin_surrogate_acc)

                # Error of last call's predictions against the current S
                # (absolute; the relative form is commented out).
                lin_pred_err = (self._last_lin_pred_S - S)  #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - S)  #/ (S + eps)
                logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
                logger.record_tabular("PredictionErrorSurrogateS",
                                      surr_pred_err)
                self._last_lin_pred_S = next_linear_S
                self._last_surr_pred_S = next_surrogate_S

            else:
                # No linear-constraint step was taken: log zeros so the
                # tabular columns stay aligned across iterations.
                logger.record_tabular("LinConstraintDelta", 0)
                logger.record_tabular("PredictedLinearS", 0)
                logger.record_tabular("PredictedSurrogateS", 0)
                logger.record_tabular("LinearSurrogateErr", 0)

                lin_pred_err = (self._last_lin_pred_S - 0)  #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - 0)  #/ (S + eps)
                logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
                logger.record_tabular("PredictionErrorSurrogateS",
                                      surr_pred_err)
                self._last_lin_pred_S = 0
                self._last_surr_pred_S = 0
Beispiel #20
0
    def optimize_expert_policies(self, itr, all_samples_data):
        """Run one optimization pass for each per-task expert optimizer.

        For optimizer ``n`` the inputs are: task ``n``'s observations,
        actions and advantages; its old dist-info arrays; then the
        observations of every task, twice. After ``optimizer.optimize`` runs,
        ``self.metrics[n]`` is evaluated on the same inputs.

        Args:
            itr: iteration index (not used in this body).
            all_samples_data: list of per-task sample dicts, indexed like
                ``self.optimizers``.
        """
        dist_info_keys = self.policy.distribution.dist_info_keys
        # The task observations do not depend on n, so build the tuple once
        # instead of once per optimizer.
        all_task_obs_values = tuple(
            task_samples["observations"] for task_samples in all_samples_data)

        for n, optimizer in enumerate(self.optimizers):
            obs_act_adv_values = tuple(
                ext.extract(all_samples_data[n], "observations", "actions",
                            "advantages"))
            dist_info_list = tuple(
                all_samples_data[n]["agent_infos"][k] for k in dist_info_keys)

            # NOTE(review): all task observations are appended twice, as in
            # the original; presumably the compiled function has two separate
            # per-task observation input groups -- confirm.
            all_input_values = (obs_act_adv_values + dist_info_list
                                + all_task_obs_values + all_task_obs_values)
            optimizer.optimize(all_input_values)

            # NOTE(review): kl_penalty is not used in the visible body.
            kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
Beispiel #21
0
 def eval(x):
     """Return ``f_Hx_plain`` evaluated at ``x`` plus ``self.reg_coeff * x``.

     ``x`` is a flat vector, unflattened via ``self.target.flat_to_params``.
     ``self`` and ``inputs`` are closed over from the enclosing scope.
     """
     unflat = tuple(self.target.flat_to_params(x, trainable=True))
     product = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)(inputs, unflat)
     return product + self.reg_coeff * x
Beispiel #22
0
    def optimize(
        self,
        inputs,
        extra_inputs=None,
        subsample_grouped_inputs=None,
        precomputed_eval=None,
        precomputed_threshold=None,
        diff_threshold=False,
        inputs2=None,
        extra_inputs2=None,
    ):
        """
        precomputed_eval         :  The value of the safety constraint at theta = theta_old. 
                                    Provide this when the lin_constraint function is a surrogate, and evaluating it at 
                                    theta_old will not give you the correct value.

        precomputed_threshold &
        diff_threshold           :  These relate to the linesearch that is used to ensure constraint satisfaction.
                                    If the lin_constraint function is indeed the safety constraint function, then it 
                                    suffices to check that lin_constraint < max_lin_constraint_val to ensure satisfaction.
                                    But if the lin_constraint function is a surrogate - ie, it only has the same
                                    /gradient/ as the safety constraint - then the threshold we check it against has to
                                    be adjusted. You can provide a fixed adjusted threshold via "precomputed_threshold."
                                    When "diff_threshold" == True, instead of checking
                                        lin_constraint < threshold,
                                    it will check
                                        lin_constraint - old_lin_constraint < threshold.
        """

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        # inputs2 and extra_inputs2 are for calculation of the linearized constraint.
        # This functionality - of having separate inputs for that constraint - is
        # intended to allow a "learning without forgetting" setup.
        if inputs2 is None:
            inputs2 = inputs
        if extra_inputs2 is None:
            extra_inputs2 = tuple()

        def subsampled_inputs(inputs, subsample_grouped_inputs):
            if self._subsample_factor < 1:
                if subsample_grouped_inputs is None:
                    subsample_grouped_inputs = [inputs]
                subsample_inputs = tuple()
                for inputs_grouped in subsample_grouped_inputs:
                    n_samples = len(inputs_grouped[0])
                    inds = np.random.choice(n_samples,
                                            int(n_samples *
                                                self._subsample_factor),
                                            replace=False)
                    subsample_inputs += tuple(
                        [x[inds] for x in inputs_grouped])
            else:
                subsample_inputs = inputs
            return subsample_inputs

        subsample_inputs = subsampled_inputs(inputs, subsample_grouped_inputs)
        if self._resample_inputs:
            subsample_inputs2 = subsampled_inputs(inputs,
                                                  subsample_grouped_inputs)

        logger.log("computing loss before")
        loss_before = sliced_fun(self._opt_fun["f_loss"],
                                 self._num_slices)(inputs, extra_inputs)
        logger.log("performing update")
        logger.log("computing descent direction")

        flat_g = sliced_fun(self._opt_fun["f_grad"],
                            self._num_slices)(inputs, extra_inputs)
        flat_b = sliced_fun(self._opt_fun["f_lin_constraint_grad"],
                            self._num_slices)(inputs2, extra_inputs2)

        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
        v = krylov.cg(Hx,
                      flat_g,
                      cg_iters=self._cg_iters,
                      verbose=self._verbose_cg)

        approx_g = Hx(v)
        q = v.dot(approx_g)  # approx = g^T H^{-1} g
        delta = 2 * self._max_quad_constraint_val

        eps = 1e-8

        residual = np.sqrt((approx_g - flat_g).dot(approx_g - flat_g))
        rescale = q / (v.dot(v))
        logger.record_tabular("OptimDiagnostic_Residual", residual)
        logger.record_tabular("OptimDiagnostic_Rescale", rescale)

        if self.precompute:
            S = precomputed_eval
            assert (np.ndim(S) == 0)  # please be a scalar
        else:
            S = sliced_fun(self._opt_fun["lin_constraint"],
                           self._num_slices)(inputs, extra_inputs)

        c = S - self._max_lin_constraint_val
        if c > 0:
            logger.log("warning! safety constraint is already violated")
        else:
            # the current parameters constitute a feasible point: save it as "last good point"
            self.last_safe_point = np.copy(
                self._target.get_param_values(trainable=True))

        # can't stop won't stop (unless something in the conditional checks / calculations that follow
        # require premature stopping of optimization process)
        stop_flag = False

        if flat_b.dot(flat_b) <= eps:
            # if safety gradient is zero, linear constraint is not present;
            # ignore its implementation.
            lam = np.sqrt(q / delta)
            nu = 0
            w = 0
            r, s, A, B = 0, 0, 0, 0
            optim_case = 4
        else:
            if self._resample_inputs:
                Hx = self._hvp_approach.build_eval(subsample_inputs2 +
                                                   extra_inputs)

            norm_b = np.sqrt(flat_b.dot(flat_b))
            unit_b = flat_b / norm_b
            w = norm_b * krylov.cg(
                Hx, unit_b, cg_iters=self._cg_iters, verbose=self._verbose_cg)

            r = w.dot(approx_g)  # approx = b^T H^{-1} g
            s = w.dot(Hx(w))  # approx = b^T H^{-1} b

            # figure out lambda coeff (lagrange multiplier for trust region)
            # and nu coeff (lagrange multiplier for linear constraint)
            A = q - r**2 / s  # this should always be positive by Cauchy-Schwarz
            B = delta - c**2 / s  # this one says whether or not the closest point on the plane is feasible

            # if (B < 0), that means the trust region plane doesn't intersect the safety boundary

            if c < 0 and B < 0:
                # point in trust region is feasible and safety boundary doesn't intersect
                # ==> entire trust region is feasible
                optim_case = 3
            elif c < 0 and B > 0:
                # x = 0 is feasible and safety boundary intersects
                # ==> most of trust region is feasible
                optim_case = 2
            elif c > 0 and B > 0:
                # x = 0 is infeasible (bad! unsafe!) and safety boundary intersects
                # ==> part of trust region is feasible
                # ==> this is 'recovery mode'
                optim_case = 1
                if self.attempt_feasible_recovery:
                    logger.log(
                        "alert! conjugate constraint optimizer is attempting feasible recovery"
                    )
                else:
                    logger.log(
                        "alert! problem is feasible but needs recovery, and we were instructed not to attempt recovery"
                    )
                    stop_flag = True
            else:
                # x = 0 infeasible (bad! unsafe!) and safety boundary doesn't intersect
                # ==> whole trust region infeasible
                # ==> optimization problem infeasible!!!
                optim_case = 0
                if self.attempt_infeasible_recovery:
                    logger.log(
                        "alert! conjugate constraint optimizer is attempting infeasible recovery"
                    )
                else:
                    logger.log(
                        "alert! problem is infeasible, and we were instructed not to attempt recovery"
                    )
                    stop_flag = True

            # default dual vars, which assume safety constraint inactive
            # (this corresponds to either optim_case == 3,
            #  or optim_case == 2 under certain conditions)
            lam = np.sqrt(q / delta)
            nu = 0

            if optim_case == 2 or optim_case == 1:

                # dual function is piecewise continuous
                # on region (a):
                #
                #   L(lam) = -1/2 (A / lam + B * lam) - r * c / s
                #
                # on region (b):
                #
                #   L(lam) = -1/2 (q / lam + delta * lam)
                #

                lam_mid = r / c
                L_mid = -0.5 * (q / lam_mid + lam_mid * delta)

                lam_a = np.sqrt(A / (B + eps))
                L_a = -np.sqrt(A * B) - r * c / (s + eps)
                # note that for optim_case == 1 or 2, B > 0, so this calculation should never be an issue

                lam_b = np.sqrt(q / delta)
                L_b = -np.sqrt(q * delta)

                #those lam's are solns to the pieces of piecewise continuous dual function.
                #the domains of the pieces depend on whether or not c < 0 (x=0 feasible),
                #and so projection back on to those domains is determined appropriately.
                if lam_mid > 0:
                    if c < 0:
                        # here, domain of (a) is [0, lam_mid)
                        # and domain of (b) is (lam_mid, infty)
                        if lam_a > lam_mid:
                            lam_a = lam_mid
                            L_a = L_mid
                        if lam_b < lam_mid:
                            lam_b = lam_mid
                            L_b = L_mid
                    else:
                        # here, domain of (a) is (lam_mid, infty)
                        # and domain of (b) is [0, lam_mid)
                        if lam_a < lam_mid:
                            lam_a = lam_mid
                            L_a = L_mid
                        if lam_b > lam_mid:
                            lam_b = lam_mid
                            L_b = L_mid

                    if L_a >= L_b:
                        lam = lam_a
                    else:
                        lam = lam_b

                else:
                    if c < 0:
                        lam = lam_b
                    else:
                        lam = lam_a

                nu = max(0, lam * c - r) / (s + eps)

        logger.record_tabular(
            "OptimCase",
            optim_case)  # 4 / 3: trust region totally in safe region;
        # 2 : trust region partly intersects safe region, and current point is feasible
        # 1 : trust region partly intersects safe region, and current point is infeasible
        # 0 : trust region does not intersect safe region
        logger.record_tabular("LagrangeLamda",
                              lam)  # dual variable for trust region
        logger.record_tabular("LagrangeNu",
                              nu)  # dual variable for safety constraint
        logger.record_tabular("OptimDiagnostic_q", q)  # approx = g^T H^{-1} g
        logger.record_tabular("OptimDiagnostic_r", r)  # approx = b^T H^{-1} g
        logger.record_tabular("OptimDiagnostic_s", s)  # approx = b^T H^{-1} b
        logger.record_tabular("OptimDiagnostic_c",
                              c)  # if > 0, constraint is violated
        logger.record_tabular("OptimDiagnostic_A", A)
        logger.record_tabular("OptimDiagnostic_B", B)
        logger.record_tabular("OptimDiagnostic_S", S)
        if nu == 0:
            logger.log("safety constraint is not active!")

        # Predict worst-case next S
        nextS = S + np.sqrt(delta * s)
        logger.record_tabular("OptimDiagnostic_WorstNextS", nextS)

        # for cases where we will not attempt recovery, we stop here. we didn't stop earlier
        # because first we wanted to record the various critical quantities for understanding the failure mode
        # (such as optim_case, B, c, S). Also, the logger gets angry if you are inconsistent about recording
        # a given quantity from iteration to iteration. That's why we have to record a BacktrackIters here.
        def record_zeros():
            logger.record_tabular("BacktrackIters", 0)
            logger.record_tabular("LossRejects", 0)
            logger.record_tabular("QuadRejects", 0)
            logger.record_tabular("LinRejects", 0)

        if optim_case > 0:
            flat_descent_step = (1. / (lam + eps)) * (v + nu * w)
        else:
            # current default behavior for attempting infeasible recovery:
            # take a step on natural safety gradient
            flat_descent_step = np.sqrt(delta / (s + eps)) * w

        logger.log("descent direction computed")

        prev_param = np.copy(self._target.get_param_values(trainable=True))

        prev_lin_constraint_val = sliced_fun(self._opt_fun["f_lin_constraint"],
                                             self._num_slices)(inputs,
                                                               extra_inputs)
        logger.record_tabular("PrevLinConstVal", prev_lin_constraint_val)

        lin_reject_threshold = self._max_lin_constraint_val
        if precomputed_threshold is not None:
            lin_reject_threshold = precomputed_threshold
        if diff_threshold:
            lin_reject_threshold += prev_lin_constraint_val
        logger.record_tabular("LinRejectThreshold", lin_reject_threshold)

        def check_nan():
            loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs, extra_inputs)
            if np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(
                    lin_constraint_val):
                logger.log("Something is NaN. Rejecting the step!")
                if np.isnan(loss):
                    logger.log("Violated because loss is NaN")
                if np.isnan(quad_constraint_val):
                    logger.log("Violated because quad_constraint %s is NaN" %
                               self._constraint_name_1)
                if np.isnan(lin_constraint_val):
                    logger.log("Violated because lin_constraint %s is NaN" %
                               self._constraint_name_2)
                self._target.set_param_values(prev_param, trainable=True)

        def line_search(check_loss=True, check_quad=True, check_lin=True):
            loss_rejects = 0
            quad_rejects = 0
            lin_rejects = 0
            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                    self._max_backtracks)):
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                    self._opt_fun["f_loss_constraint"],
                    self._num_slices)(inputs, extra_inputs)
                loss_flag = loss < loss_before
                quad_flag = quad_constraint_val <= self._max_quad_constraint_val
                lin_flag = lin_constraint_val <= lin_reject_threshold
                if check_loss and not (loss_flag):
                    logger.log("At backtrack itr %i, loss failed to improve." %
                               n_iter)
                    loss_rejects += 1
                if check_quad and not (quad_flag):
                    logger.log(
                        "At backtrack itr %i, quad constraint violated." %
                        n_iter)
                    logger.log(
                        "Quad constraint violation was %.3f %%." %
                        (100 *
                         (quad_constraint_val / self._max_quad_constraint_val)
                         - 100))
                    quad_rejects += 1
                if check_lin and not (lin_flag):
                    logger.log(
                        "At backtrack itr %i, expression for lin constraint failed to improve."
                        % n_iter)
                    logger.log(
                        "Lin constraint violation was %.3f %%." %
                        (100 *
                         (lin_constraint_val / lin_reject_threshold) - 100))
                    lin_rejects += 1

                if (loss_flag or not (check_loss)) and (
                        quad_flag
                        or not (check_quad)) and (lin_flag or not (check_lin)):
                    logger.log("Accepted step at backtrack itr %i." % n_iter)
                    break

            logger.record_tabular("BacktrackIters", n_iter)
            logger.record_tabular("LossRejects", loss_rejects)
            logger.record_tabular("QuadRejects", quad_rejects)
            logger.record_tabular("LinRejects", lin_rejects)
            return loss, quad_constraint_val, lin_constraint_val, n_iter

        def wrap_up():
            if optim_case < 4:
                lin_constraint_val = sliced_fun(
                    self._opt_fun["f_lin_constraint"],
                    self._num_slices)(inputs, extra_inputs)
                lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
                logger.record_tabular("LinConstraintDelta",
                                      lin_constraint_delta)

                cur_param = self._target.get_param_values()

                next_linear_S = S + flat_b.dot(cur_param - prev_param)
                next_surrogate_S = S + lin_constraint_delta

                lin_surrogate_acc = 100. * (
                    next_linear_S - next_surrogate_S) / next_surrogate_S

                logger.record_tabular("PredictedLinearS", next_linear_S)
                logger.record_tabular("PredictedSurrogateS", next_surrogate_S)
                logger.record_tabular("LinearSurrogateErr", lin_surrogate_acc)

                lin_pred_err = (self._last_lin_pred_S - S)  #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - S)  #/ (S + eps)
                logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
                logger.record_tabular("PredictionErrorSurrogateS",
                                      surr_pred_err)
                self._last_lin_pred_S = next_linear_S
                self._last_surr_pred_S = next_surrogate_S

            else:
                logger.record_tabular("LinConstraintDelta", 0)
                logger.record_tabular("PredictedLinearS", 0)
                logger.record_tabular("PredictedSurrogateS", 0)
                logger.record_tabular("LinearSurrogateErr", 0)

                lin_pred_err = (self._last_lin_pred_S - 0)  #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - 0)  #/ (S + eps)
                logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
                logger.record_tabular("PredictionErrorSurrogateS",
                                      surr_pred_err)
                self._last_lin_pred_S = 0
                self._last_surr_pred_S = 0

        if stop_flag == True:
            record_zeros()
            wrap_up()
            return

        if optim_case == 1 and not (self.revert_to_last_safe_point):
            if self._linesearch_infeasible_recovery:
                logger.log(
                    "feasible recovery mode: constrained natural gradient step. performing linesearch on constraints."
                )
                line_search(False, True, True)
            else:
                self._target.set_param_values(prev_param - flat_descent_step,
                                              trainable=True)
                logger.log(
                    "feasible recovery mode: constrained natural gradient step. no linesearch performed."
                )
            check_nan()
            record_zeros()
            wrap_up()
            return
        elif optim_case == 0 and not (self.revert_to_last_safe_point):
            if self._linesearch_infeasible_recovery:
                logger.log(
                    "infeasible recovery mode: natural safety step. performing linesearch on constraints."
                )
                line_search(False, True, True)
            else:
                self._target.set_param_values(prev_param - flat_descent_step,
                                              trainable=True)
                logger.log(
                    "infeasible recovery mode: natural safety gradient step. no linesearch performed."
                )
            check_nan()
            record_zeros()
            wrap_up()
            return
        elif (optim_case == 0
              or optim_case == 1) and self.revert_to_last_safe_point:
            if self.last_safe_point:
                self._target.set_param_values(self.last_safe_point,
                                              trainable=True)
                logger.log(
                    "infeasible recovery mode: reverted to last safe point!")
            else:
                logger.log(
                    "alert! infeasible recovery mode failed: no last safe point to revert to."
                )
            record_zeros()
            wrap_up()
            return

        loss, quad_constraint_val, lin_constraint_val, n_iter = line_search()

        if (np.isnan(loss) or np.isnan(quad_constraint_val)
                or np.isnan(lin_constraint_val) or loss >= loss_before
                or quad_constraint_val >= self._max_quad_constraint_val
                or lin_constraint_val > lin_reject_threshold
            ) and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(quad_constraint_val):
                logger.log("Violated because quad_constraint %s is NaN" %
                           self._constraint_name_1)
            if np.isnan(lin_constraint_val):
                logger.log("Violated because lin_constraint %s is NaN" %
                           self._constraint_name_2)
            if loss >= loss_before:
                logger.log("Violated because loss not improving")
            if quad_constraint_val >= self._max_quad_constraint_val:
                logger.log("Violated because constraint %s is violated" %
                           self._constraint_name_1)
            if lin_constraint_val > lin_reject_threshold:
                logger.log(
                    "Violated because constraint %s exceeded threshold" %
                    self._constraint_name_2)
            self._target.set_param_values(prev_param, trainable=True)
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
        wrap_up()
Beispiel #23
0
 def constraint_val(self, inputs, extra_inputs=None):
     """Evaluate the compiled ``f_constraint`` function on the given inputs.

     The function is wrapped with ``sliced_fun`` using ``self._num_slices``
     (presumably evaluating the batch in slices and aggregating -- confirm
     against ``sliced_fun``).

     :param inputs: sequence of input arrays; converted to a tuple.
     :param extra_inputs: optional tuple of extra inputs; ``None`` means none.
     :return: the value produced by the sliced ``f_constraint`` call.
     """
     as_tuple = tuple(inputs)
     extras = extra_inputs if extra_inputs is not None else tuple()
     evaluate = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
     return evaluate(as_tuple, extras)
Beispiel #24
0
 def loss(self, inputs, extra_inputs=None):
     """Evaluate the compiled ``f_loss`` function on the given inputs.

     The function is wrapped with ``sliced_fun`` using ``self._num_slices``
     (presumably evaluating the batch in slices and aggregating -- confirm
     against ``sliced_fun``).

     :param inputs: sequence of input arrays; converted to a tuple.
     :param extra_inputs: optional tuple of extra inputs; ``None`` means none.
     :return: the value produced by the sliced ``f_loss`` call.
     """
     as_tuple = tuple(inputs)
     extras = extra_inputs if extra_inputs is not None else tuple()
     evaluate = sliced_fun(self._opt_fun["f_loss"], self._num_slices)
     return evaluate(as_tuple, extras)
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):
        """Take one conjugate-gradient trust-region step on ``self._target``.

        Steps:
          1. Optionally subsample the inputs (when ``self._subsample_factor < 1``)
             for the Hessian-vector products only.
          2. Compute the loss gradient, approximately solve ``H d = g`` with
             CG, and scale ``d`` so that ``0.5 * step^T H step`` equals
             ``self._max_constraint_val``.
          3. Backtrack along the step (factor ``self._backtrack_ratio``, at
             most ``self._max_backtracks`` tries) until the loss improves and
             the constraint value is ``<= self._max_constraint_val``.
          4. If the final candidate still fails those tests and
             ``self._accept_violation`` is false, restore the original
             parameters.

        :param inputs: sequence of input arrays; the subsampling code indexes
            each along axis 0 by sample.
        :param extra_inputs: optional tuple of extra inputs appended after
            ``inputs`` for every compiled function; ``None`` means none.
        :param subsample_grouped_inputs: optional list of input groups that
            are subsampled independently; defaults to ``[inputs]``.
        """
        prev_param = np.copy(self._target.get_param_values(trainable=True))
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        # The subsample (drawn without replacement) is only used to build the
        # Hessian-vector product; loss and gradient use the full inputs.
        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(n_samples,
                                        int(n_samples *
                                            self._subsample_factor),
                                        replace=False)
                subsample_inputs += tuple(x[inds] for x in inputs_grouped)
        else:
            subsample_inputs = inputs

        logger.log(
            "Start CG optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
            % (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

        logger.log("computing loss before")
        loss_before = sliced_fun(self._opt_fun["f_loss"],
                                 self._num_slices)(inputs, extra_inputs)
        logger.log("performing update")

        logger.log("computing gradient")
        flat_g = sliced_fun(self._opt_fun["f_grad"],
                            self._num_slices)(inputs, extra_inputs)
        logger.log("gradient computed")

        logger.log("computing descent direction")
        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

        descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)

        # Scale the direction so 0.5 * step^T H step == max_constraint_val
        # (1e-8 guards the denominator). A negative curvature estimate makes
        # the sqrt argument negative -> NaN, in which case fall back to 1.
        initial_step_size = np.sqrt(
            2.0 * self._max_constraint_val *
            (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
        if np.isnan(initial_step_size):
            initial_step_size = 1.
        flat_descent_step = initial_step_size * descent_direction

        logger.log("descent direction computed")

        n_iter = 0
        # Stay None only if the loop below never runs (max_backtracks <= 0).
        loss = constraint_val = None
        for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs, extra_inputs)
            if self._debug_nan and np.isnan(constraint_val):
                import ipdb
                ipdb.set_trace()
            if loss < loss_before and constraint_val <= self._max_constraint_val:
                break

        if loss is None:
            # No candidate step was evaluated, so the parameters were never
            # modified; there is nothing to accept or reject.
            logger.log(
                "No line search iterations were run. Keeping the current parameters.")
        elif (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
                # `>` (not `>=`) so this rejection test is the exact negation
                # of the loop's acceptance test `constraint_val <= max`.
                or constraint_val > self._max_constraint_val
              ) and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint %s is NaN" %
                           self._constraint_name)
            if loss >= loss_before:
                logger.log("Violated because loss not improving")
            if constraint_val > self._max_constraint_val:
                logger.log("Violated because constraint %s is violated" %
                           self._constraint_name)
            self._target.set_param_values(prev_param, trainable=True)
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
 def eval(x):
     """Evaluate ``f_Hx_plain`` at the flat vector ``x`` plus a damping term.

     ``self`` and ``inputs`` are free variables taken from the enclosing scope
     (not visible here -- presumably the enclosing ``build_eval``). By its
     name, ``f_Hx_plain`` is a Hessian-vector product; confirm upstream.

     :param x: flat parameter-space vector.
     :return: ``f_Hx_plain(inputs, params(x)) + self.reg_coeff * x``.
     """
     unflattened = tuple(self.target.flat_to_params(x, trainable=True))
     hx_fn = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)
     return hx_fn(inputs, unflattened) + self.reg_coeff * x
    def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
        """Take one conjugate-gradient trust-region step on ``self._target``.

        Steps:
          1. Optionally subsample the inputs (when ``self._subsample_factor < 1``)
             for the Hessian-vector products only.
          2. Compute the loss gradient, approximately solve ``H d = g`` with
             CG, and scale ``d`` so that ``0.5 * step^T H step`` equals
             ``self._max_constraint_val``.
          3. Backtrack along the step (factor ``self._backtrack_ratio``, at
             most ``self._max_backtracks`` tries) until the loss improves and
             the constraint value is ``<= self._max_constraint_val``.
          4. If the final candidate still fails those tests and
             ``self._accept_violation`` is false, restore the original
             parameters.

        :param inputs: sequence of input arrays; the subsampling code indexes
            each along axis 0 by sample.
        :param extra_inputs: optional tuple of extra inputs appended after
            ``inputs`` for every compiled function; ``None`` means none.
        :param subsample_grouped_inputs: optional list of input groups that
            are subsampled independently; defaults to ``[inputs]``.
        """
        prev_param = np.copy(self._target.get_param_values(trainable=True))
        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        # The subsample (drawn without replacement) is only used to build the
        # Hessian-vector product; loss and gradient use the full inputs.
        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(
                    n_samples, int(n_samples * self._subsample_factor), replace=False)
                subsample_inputs += tuple(x[inds] for x in inputs_grouped)
        else:
            subsample_inputs = inputs

        logger.log("Start CG optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
                   % (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

        logger.log("computing loss before")
        loss_before = sliced_fun(self._opt_fun["f_loss"], self._num_slices)(inputs, extra_inputs)
        logger.log("performing update")

        logger.log("computing gradient")
        flat_g = sliced_fun(self._opt_fun["f_grad"], self._num_slices)(inputs, extra_inputs)
        logger.log("gradient computed")

        logger.log("computing descent direction")
        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

        descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)

        # Scale the direction so 0.5 * step^T H step == max_constraint_val
        # (1e-8 guards the denominator). A negative curvature estimate makes
        # the sqrt argument negative -> NaN, in which case fall back to 1.
        initial_step_size = np.sqrt(
            2.0 * self._max_constraint_val * (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8))
        )
        if np.isnan(initial_step_size):
            initial_step_size = 1.
        flat_descent_step = initial_step_size * descent_direction

        logger.log("descent direction computed")

        n_iter = 0
        # Stay None only if the loop below never runs (max_backtracks <= 0).
        loss = constraint_val = None
        for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, constraint_val = sliced_fun(self._opt_fun["f_loss_constraint"], self._num_slices)(
                inputs, extra_inputs)
            if self._debug_nan and np.isnan(constraint_val):
                import ipdb
                ipdb.set_trace()
            if loss < loss_before and constraint_val <= self._max_constraint_val:
                break

        if loss is None:
            # No candidate step was evaluated, so the parameters were never
            # modified; there is nothing to accept or reject.
            logger.log("No line search iterations were run. Keeping the current parameters.")
        elif (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
                # `>` (not `>=`) so this rejection test is the exact negation
                # of the loop's acceptance test `constraint_val <= max`.
                or constraint_val > self._max_constraint_val) and not self._accept_violation:
            logger.log("Line search condition violated. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(constraint_val):
                logger.log("Violated because constraint %s is NaN" % self._constraint_name)
            if loss >= loss_before:
                logger.log("Violated because loss not improving")
            if constraint_val > self._max_constraint_val:
                logger.log("Violated because constraint %s is violated" % self._constraint_name)
            self._target.set_param_values(prev_param, trainable=True)
        logger.log("backtrack iters: %d" % n_iter)
        logger.log("computing loss after")
        logger.log("optimization finished")
 def constraint_val(self, inputs, extra_inputs=None):
     """Evaluate the compiled ``f_constraint`` function on the given inputs.

     The function is wrapped with ``sliced_fun`` using ``self._num_slices``
     (presumably evaluating the batch in slices and aggregating -- confirm
     against ``sliced_fun``).

     :param inputs: sequence of input arrays; converted to a tuple.
     :param extra_inputs: optional tuple of extra inputs; ``None`` means none.
     :return: the value produced by the sliced ``f_constraint`` call.
     """
     as_tuple = tuple(inputs)
     extras = extra_inputs if extra_inputs is not None else tuple()
     evaluate = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
     return evaluate(as_tuple, extras)
 def loss(self, inputs, extra_inputs=None):
     """Evaluate the compiled ``f_loss`` function on the given inputs.

     The function is wrapped with ``sliced_fun`` using ``self._num_slices``
     (presumably evaluating the batch in slices and aggregating -- confirm
     against ``sliced_fun``).

     :param inputs: sequence of input arrays; converted to a tuple.
     :param extra_inputs: optional tuple of extra inputs; ``None`` means none.
     :return: the value produced by the sliced ``f_loss`` call.
     """
     as_tuple = tuple(inputs)
     extras = extra_inputs if extra_inputs is not None else tuple()
     evaluate = sliced_fun(self._opt_fun["f_loss"], self._num_slices)
     return evaluate(as_tuple, extras)
Beispiel #30
0
    def optimize(self, 
                 inputs, 
                 extra_inputs=None, 
                 subsample_grouped_inputs=None, 
                 precomputed_eval=None, 
                 precomputed_threshold=None,
                 diff_threshold=False,
                 inputs2=None,
                 extra_inputs2=None,
                ):

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()

        if inputs2 is None:
            inputs2 = inputs
        if extra_inputs2 is None:
            extra_inputs2 = tuple()

        def subsampled_inputs(inputs,subsample_grouped_inputs):
            if self._subsample_factor < 1:
                if subsample_grouped_inputs is None:
                    subsample_grouped_inputs = [inputs]
                subsample_inputs = tuple()
                for inputs_grouped in subsample_grouped_inputs:
                    n_samples = len(inputs_grouped[0])
                    inds = np.random.choice(
                        n_samples, int(n_samples * self._subsample_factor), replace=False)
                    subsample_inputs += tuple([x[inds] for x in inputs_grouped])
            else:
                subsample_inputs = inputs
            return subsample_inputs

        subsample_inputs = subsampled_inputs(inputs,subsample_grouped_inputs)
        if self._resample_inputs:
            subsample_inputs2 = subsampled_inputs(inputs,subsample_grouped_inputs)

        loss_before = sliced_fun(self._opt_fun["f_loss"], self._num_slices)(
            inputs, extra_inputs)

        flat_g = sliced_fun(self._opt_fun["f_grad"], self._num_slices)(
            inputs, extra_inputs)
        flat_b = sliced_fun(self._opt_fun["f_lin_constraint_grad"], self._num_slices)(
            inputs2, extra_inputs2)

        Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
        v = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters, verbose=self._verbose_cg)

        approx_g = Hx(v)
        q = v.dot(approx_g) 
        delta = 2 * self._max_quad_constraint_val
 
        eps = 1e-8

        residual = np.sqrt((approx_g - flat_g).dot(approx_g - flat_g))
        rescale  = q / (v.dot(v))

        if self.precompute:
            S = precomputed_eval
            assert(np.ndim(S)==0) 
        else:
            S = sliced_fun(self._opt_fun["lin_constraint"], self._num_slices)(inputs, extra_inputs) 

        c = S - self._max_lin_constraint_val
        if c > 0:
            logger.log("warning! safety constraint is already violated")
        else:
            self.last_safe_point = np.copy(self._target.get_param_values(trainable=True))

        stop_flag = False

        if flat_b.dot(flat_b) <= eps :
            lam = np.sqrt(q / delta)
            nu = 0
            w = 0
            r,s,A,B = 0,0,0,0
            optim_case = 4
        else:
            if self._resample_inputs:
                Hx = self._hvp_approach.build_eval(subsample_inputs2 + extra_inputs)

            norm_b = np.sqrt(flat_b.dot(flat_b))
            unit_b = flat_b / norm_b
            w = norm_b * krylov.cg(Hx, unit_b, cg_iters=self._cg_iters, verbose=self._verbose_cg)

            r = w.dot(approx_g) # approx = b^T H^{-1} g
            s = w.dot(Hx(w))    # approx = b^T H^{-1} b

            A = q - r**2 / s                # this should always be positive by Cauchy-Schwarz
            B = delta - c**2 / s            # this one says whether or not the closest point on the plane is feasible

            if c <0 and B < 0:
                optim_case = 3
            elif c < 0 and B > 0:
                optim_case = 2
            elif c > 0 and B > 0:
                optim_case = 1
                if self.attempt_feasible_recovery:
                    logger.log("alert! conjugate constraint optimizer is attempting feasible recovery")
                else:
                    logger.log("alert! problem is feasible but needs recovery, and we were instructed not to attempt recovery")
                    stop_flag = True
            else:
                optim_case = 0
                if self.attempt_infeasible_recovery:
                    logger.log("alert! conjugate constraint optimizer is attempting infeasible recovery")
                else:
                    logger.log("alert! problem is infeasible, and we were instructed not to attempt recovery")
                    stop_flag = True

            lam = np.sqrt(q / delta)
            nu  = 0

            if optim_case == 2 or optim_case == 1:
                lam_mid = r / c
                L_mid = - 0.5 * (q / lam_mid + lam_mid * delta)

                lam_a = np.sqrt(A / (B + eps))
                L_a = -np.sqrt(A*B) - r*c / (s + eps)                 

                lam_b = np.sqrt(q / delta)
                L_b = -np.sqrt(q * delta)

                if lam_mid > 0:
                    if c < 0:
                        if lam_a > lam_mid:
                            lam_a = lam_mid
                            L_a   = L_mid
                        if lam_b < lam_mid:
                            lam_b = lam_mid
                            L_b   = L_mid
                    else:
                        if lam_a < lam_mid:
                            lam_a = lam_mid
                            L_a   = L_mid
                        if lam_b > lam_mid:
                            lam_b = lam_mid
                            L_b   = L_mid

                    if L_a >= L_b:
                        lam = lam_a
                    else:
                        lam = lam_b

                else:
                    if c < 0:
                        lam = lam_b
                    else:
                        lam = lam_a

                nu = max(0, lam * c - r) / (s + eps)
        nextS = S + np.sqrt(delta * s)


        def record_zeros():
            logger.record_tabular("BacktrackIters", 0)
            logger.record_tabular("LossRejects", 0)
            logger.record_tabular("QuadRejects", 0)
            logger.record_tabular("LinRejects", 0)


        if optim_case > 0:
            flat_descent_step = (1. / (lam + eps) ) * ( v + nu * w )
        else:
            flat_descent_step = np.sqrt(delta / (s + eps)) * w

        prev_param = np.copy(self._target.get_param_values(trainable=True))

        prev_lin_constraint_val = sliced_fun(
            self._opt_fun["f_lin_constraint"], self._num_slices)(inputs, extra_inputs)

        lin_reject_threshold = self._max_lin_constraint_val
        if precomputed_threshold is not None:
            lin_reject_threshold = precomputed_threshold
        if diff_threshold:
            lin_reject_threshold += prev_lin_constraint_val

        def check_nan():
            loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"], self._num_slices)(inputs, extra_inputs)
            if np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(lin_constraint_val):
                if np.isnan(loss):
                    logger.log("Violated because loss is NaN")
                if np.isnan(quad_constraint_val):
                    logger.log("Violated because quad_constraint %s is NaN" %
                               self._constraint_name_1)
                if np.isnan(lin_constraint_val):
                    logger.log("Violated because lin_constraint %s is NaN" %
                               self._constraint_name_2)
                self._target.set_param_values(prev_param, trainable=True)

        def line_search(check_loss=True, check_quad=True, check_lin=True):
            loss_rejects = 0
            quad_rejects = 0
            lin_rejects  = 0
            n_iter = 0
            for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)):
                cur_step = ratio * flat_descent_step
                cur_param = prev_param - cur_step
                self._target.set_param_values(cur_param, trainable=True)
                loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                    self._opt_fun["f_loss_constraint"], self._num_slices)(inputs, extra_inputs)
                loss_flag = loss < loss_before
                quad_flag = quad_constraint_val <= self._max_quad_constraint_val
                lin_flag  = lin_constraint_val  <= lin_reject_threshold
                if check_loss and not(loss_flag):
                    loss_rejects += 1
                if check_quad and not(quad_flag):
                    quad_rejects += 1
                if check_lin and not(lin_flag):
                    lin_rejects += 1

                if (loss_flag or not(check_loss)) and (quad_flag or not(check_quad)) and (lin_flag or not(check_lin)):
                    break

            return loss, quad_constraint_val, lin_constraint_val, n_iter


        def wrap_up():
            """Refresh the stored one-step-ahead predictions of ``S``.

            When ``optim_case < 4``, the function does the following:

            * Re-evaluates ``f_lin_constraint`` at the current parameters.
            * Records its change from ``prev_lin_constraint_val`` as the
              tabular value "LinConstraintDelta".
            * Stores two predictions for the next ``S`` on ``self``:
              - a first-order one, ``S + flat_b . (cur_param - prev_param)``;
              - one from the measured delta, ``S + lin_constraint_delta``.

            Otherwise, both stored predictions are reset to 0.

            NOTE(review): S appears to be the current linear-constraint
            (safety) quantity -- confirm against where S is defined.
            NOTE(review): lin_surrogate_acc, lin_pred_err and surr_pred_err are
            computed but never logged or returned here.
            """
            if optim_case < 4:
                lin_constraint_val = sliced_fun(
                    self._opt_fun["f_lin_constraint"], self._num_slices)(inputs, extra_inputs)
                lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
                logger.record_tabular("LinConstraintDelta",lin_constraint_delta)

                # NOTE(review): unlike every other get/set in this method this
                # omits trainable=True; confirm the result has the same shape
                # as prev_param.
                cur_param = self._target.get_param_values()

                next_linear_S = S + flat_b.dot(cur_param - prev_param)
                next_surrogate_S = S + lin_constraint_delta

                # Percent disagreement between the two predictions.
                # NOTE(review): divides by next_surrogate_S, which may be 0.
                lin_surrogate_acc = 100.*(next_linear_S - next_surrogate_S) / next_surrogate_S

                # Error of the predictions stored on the previous call vs. the current S.
                lin_pred_err = (self._last_lin_pred_S - S) #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - S) #/ (S + eps)
                self._last_lin_pred_S = next_linear_S
                self._last_surr_pred_S = next_surrogate_S

            else:
                lin_pred_err = (self._last_lin_pred_S - 0) #/ (S + eps)
                surr_pred_err = (self._last_surr_pred_S - 0) #/ (S + eps)
                self._last_lin_pred_S = 0
                self._last_surr_pred_S = 0

        if stop_flag==True:
            record_zeros()
            wrap_up()
            return

        if optim_case == 1 and not(self.revert_to_last_safe_point):
            if self._linesearch_infeasible_recovery:
                logger.log("feasible recovery mode: constrained natural gradient step. performing linesearch on constraints.")
                line_search(False,True,True)
            else:
                self._target.set_param_values(prev_param - flat_descent_step, trainable=True)
                logger.log("feasible recovery mode: constrained natural gradient step. no linesearch performed.")
            check_nan()
            record_zeros()
            wrap_up()
            return
        elif optim_case == 0 and not(self.revert_to_last_safe_point):
            if self._linesearch_infeasible_recovery:
                logger.log("infeasible recovery mode: natural safety step. performing linesearch on constraints.")
                line_search(False,True,True)
            else:
                self._target.set_param_values(prev_param - flat_descent_step, trainable=True)
                logger.log("infeasible recovery mode: natural safety gradient step. no linesearch performed.")
            check_nan()
            record_zeros()
            wrap_up()
            return
        elif (optim_case == 0 or optim_case == 1) and self.revert_to_last_safe_point:
            if self.last_safe_point:
                self._target.set_param_values(self.last_safe_point, trainable=True)
                logger.log("infeasible recovery mode: reverted to last safe point!")
            else:
                logger.log("alert! infeasible recovery mode failed: no last safe point to revert to.")
            record_zeros()
            wrap_up()
            return


        loss, quad_constraint_val, lin_constraint_val, n_iter = line_search()

        if (np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(lin_constraint_val) or loss >= loss_before 
            or quad_constraint_val >= self._max_quad_constraint_val
            or lin_constraint_val > lin_reject_threshold) and not self._accept_violation:
            self._target.set_param_values(prev_param, trainable=True)
        wrap_up()
Beispiel #31
0
 def f_opt_wrapper(flat_params):
     """Evaluate ``f_opt`` on ``inputs`` at the parameter vector ``flat_params``.

     Side effect: the trainable parameters of ``self._target`` are overwritten
     with ``flat_params`` before evaluation and are not restored afterwards.
     ``self``, ``f_opt`` and ``inputs`` are free variables, presumably bound by
     an enclosing method.
     """
     target = self._target
     target.set_param_values(flat_params, trainable=True)
     evaluate = sliced_fun(f_opt, self._n_slices)
     return evaluate(inputs)
Beispiel #32
0
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):
        """Run SVRG-style stochastic optimization with a CG step per mini-batch.

        Each epoch proceeds as follows:

        1. Computes ``g_mean_tilde``, the mean gradient over all ``inputs`` at
           the snapshot parameters (``f_grad_tilde``).
        2. Draws ``self._max_batch`` random mini-batches. For each one it:
           - forms either a plain stochastic gradient (when ``self._use_SGD``)
             or the variance-reduced estimate
             ``f_grad(batch) - f_grad_tilde(batch) + g_mean_tilde``;
           - builds ``Hx`` from a random subsample via
             ``self._hvp_approach.build_eval`` (presumably a Hessian-vector
             product evaluator);
           - passes ``g`` and ``Hx`` to ``self.conjugate_grad``, whose return
             value is not used here.
        3. Copies the current parameters into ``self._target_tilde``.
        4. Stops early once the relative parameter change
           ``|w - w_tilde| / |w|`` drops below ``self._tolerance``.

        :param inputs: non-empty sequence of input arrays; ``len(inputs[0])``
            is taken as the number of samples.
        :param extra_inputs: optional sequence of extra inputs forwarded with
            every batch.
        :param subsample_grouped_inputs: accepted for interface compatibility;
            not used.
        :raises NotImplementedError: if ``inputs`` is empty.
        """
        if len(inputs) == 0:
            # Mini-batches are always sampled from ``inputs``.
            raise NotImplementedError

        f_grad = self._opt_fun["f_grad"]
        f_grad_tilde = self._opt_fun["f_grad_tilde"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(self._target.get_param_values(trainable=True))
        n_samples = len(inputs[0])
        logger.log(
            "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
            % (len(param), n_samples, self._subsample_factor * n_samples))

        subsamples = BatchDataset(inputs,
                                  int(self._subsample_factor * n_samples),
                                  extra_inputs=extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(n_samples)
            # g_u = 1/n \sum_{b} \partial{loss(w_tilde, b)} / \partial{w_tilde}
            grad_sum = np.zeros_like(param)
            g_mean_tilde = sliced_fun(f_grad_tilde,
                                      self._num_slices)(inputs, extra_inputs)
            logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde))
            logger.log("-------------mini-batch-------------------")
            num_batch = 0
            batch = None
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                # todo, pick mini-batch with weighted prob.
                if self._use_SGD:
                    g = f_grad(*batch)
                else:
                    # Variance-reduced gradient: correct the mini-batch gradient
                    # with the snapshot gradient on the same batch.
                    g = f_grad(*batch) - f_grad_tilde(*batch) + g_mean_tilde
                grad_sum += g
                subsample_inputs = subsamples.random_batch()
                Hx = self._hvp_approach.build_eval(subsample_inputs)
                self.conjugate_grad(g, Hx, inputs, extra_inputs)
                num_batch += 1
            logger.log("max batch achieved {:}".format(num_batch))
            # Guard against self._max_batch <= 0: no batch was drawn, so there
            # is nothing to average and no batch size to report.
            if num_batch > 0:
                grad_sum /= num_batch
            if self._verbose and batch is not None:
                progbar.update(batch[0].shape[0])
            logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde))

            cur_w = np.copy(self._target.get_param_values(trainable=True))
            w_tilde = self._target_tilde.get_param_values(trainable=True)
            # Move the snapshot to the current parameters for the next epoch.
            self._target_tilde.set_param_values(cur_w, trainable=True)
            cur_w_norm = LA.norm(cur_w)
            # Relative change of the parameters since the previous snapshot;
            # used both for logging and as the convergence criterion.
            w_dist = LA.norm(cur_w - w_tilde) / cur_w_norm
            logger.record_tabular('wnorm', cur_w_norm)
            logger.record_tabular('w_dist', w_dist)

            if self._verbose and progbar.active:
                progbar.stop()

            if w_dist < self._tolerance:
                break
Beispiel #33
0
 def loss(self, inputs, extra_inputs=None):
     """Return the compiled ``f_loss`` evaluated on ``inputs``.

     The call goes through ``sliced_fun`` with ``self._n_slices`` slices.
     ``extra_inputs`` defaults to an empty list and is passed as the second
     argument of the sliced call.
     """
     extra = list() if extra_inputs is None else extra_inputs
     sliced_loss = sliced_fun(self._opt_fun["f_loss"], self._n_slices)
     return sliced_loss(inputs, extra)
Beispiel #34
0
 def loss(self, inputs, extra_inputs=None):
     """Return the loss summed over the shared ``loss`` slots of all workers.

     This worker performs three steps:

     1. Evaluates ``f_loss`` through ``ext.sliced_fun`` with
        ``self._num_slices`` slices.
     2. Scales the result by ``self.avg_fac`` and writes it to slot
        ``self.rank`` of the shared ``loss`` array.
     3. Waits on the ``loss`` barrier, then returns the sum of all slots.

     NOTE(review): presumably one slot per worker, and all workers must call
     this in lock-step to reach the barrier. There is no second barrier after
     the sum, so confirm a fast worker cannot overwrite its slot on a later
     call before slower workers have summed.
     """
     shared_vals, sync = self._par_objs
     scale = self.avg_fac
     local_loss = ext.sliced_fun(
         self._opt_fun["f_loss"], self._num_slices)(inputs, extra_inputs)
     shared_vals.loss[self.rank] = scale * local_loss
     sync.loss.wait()
     return sum(shared_vals.loss)