Example #1
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_var_grad = self._opt_fun["f_var_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(
            self._target.get_mean_network().get_param_values(trainable=True))
        logger.log(
            "Start SVRG CG subsample optimization: "
            "#parameters: %d, #inputs: %d, #subsample_inputs: %d" %
            (len(param), len(inputs[0]),
             self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(inputs,
                                  int(self._subsample_factor * len(inputs[0])),
                                  extra_inputs=extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                subsample_inputs = subsamples.random_batch()
                # update mean_network weights
                g = f_grad(*(batch))
                mean_Hx = self._mean_hvp.build_eval(subsample_inputs)
                self._target_network = self._target.get_mean_network()
                self.conjugate_grad(g, mean_Hx, inputs, extra_inputs)

                # update var_network weights
                var_g = f_var_grad(*(batch))
                var_Hx = self._var_hvp.build_eval(subsample_inputs)
                self._target_network = self._target.get_std_network()
                self.conjugate_grad(var_g, var_Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            if self._verbose:
                progbar.update(batch[0].shape[0])

            if self._verbose:
                if progbar.active:
                    progbar.stop()
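
The conjugate_grad helper called twice above (once for the mean network, once for the std network) is not shown in this example. As a rough illustration only, the sketch below is a minimal stand-alone conjugate-gradient solve of H x = g driven by a Hessian-vector-product callable such as the mean_Hx / var_Hx objects returned by build_eval; the name cg_solve and the damping constant are hypothetical and not part of the code above.

import numpy as np

def cg_solve(hvp, g, cg_iters=10, residual_tol=1e-10):
    # Solve H x = g using only Hessian-vector products hvp(v) ~= H v,
    # never forming H explicitly (the idea the CG step above relies on).
    x = np.zeros_like(g)
    r = g.copy()               # residual, r = g - H x with x = 0
    p = r.copy()               # current search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Hp = hvp(p)
        alpha = rdotr / (p.dot(Hp) + 1e-8)   # small damping avoids division by zero
        x += alpha * p
        r -= alpha * Hp
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# Hypothetical usage: x approximates H^{-1} g and would then be scaled into a
# parameter update for the corresponding network.
# direction = cg_solve(mean_Hx, g)
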
Example #2
    def optimize(self, inputs, extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(self._target.get_param_values(trainable=True))
        logger.log("Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d" % (
            len(param), len(inputs[0]), self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(
            inputs,
            int(self._subsample_factor * len(inputs[0])),
            extra_inputs=extra_inputs)

        dataset = BatchDataset(
            inputs,
            self._batch_size,
            extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            # g_u = 1/n \sum_b \partial loss(w_tilde, b) / \partial w_tilde
            print("-------------mini-batch-------------------")
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                # todo, pick mini-batch with weighted prob.
                g = f_grad(*(batch))
                subsample_inputs = subsamples.random_batch()
                Hx = self._hvp_approach.build_eval(subsample_inputs)
                self.conjugate_grad(g, Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            if self._verbose:
                progbar.update(batch[0].shape[0])

            cur_w = np.copy(self._target.get_param_values(trainable=True))
            logger.record_tabular('wnorm', LA.norm(cur_w))

            if self._verbose:
                if progbar.active:
                    progbar.stop()
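
The dataset and subsamples objects used above only need to hand back random mini-batches of the inputs (with extra_inputs appended), where the subsample set is sized by subsample_factor so that the Hessian-vector products are evaluated on a fraction of the data. The class below is a minimal stand-in sketch of that behavior, not the actual rllab BatchDataset implementation.

import numpy as np

class SimpleBatchDataset(object):
    # Minimal stand-in: stores aligned input arrays and returns random
    # mini-batches of a fixed size, appending the extra inputs unchanged.
    def __init__(self, inputs, batch_size, extra_inputs=()):
        self._inputs = [np.asarray(x) for x in inputs]
        self._batch_size = batch_size
        self._extra_inputs = tuple(extra_inputs)
        self._n = len(self._inputs[0])

    def random_batch(self):
        idx = np.random.randint(0, self._n, size=self._batch_size)
        return tuple(x[idx] for x in self._inputs) + self._extra_inputs

# With subsample_factor = 0.1, for example, the curvature estimates used in
# the CG step are computed from roughly 10% of the samples per iteration.
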
Example #3
    def optimize(self, inputs, extra_inputs=None, callback=None):
        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_var_grad = self._opt_fun["f_var_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        mean_w = self._target.get_mean_network().get_param_values(
            trainable=True)
        var_w = self._target.get_std_network().get_param_values(trainable=True)
        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            num_batch = 0
            loss = f_loss(*(tuple(inputs) + extra_inputs))
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                g = f_grad(*(batch))
                # w = w - \eta g
                # pdb.set_trace()
                mean_w = mean_w - self._learning_rate * g
                self._target.get_mean_network().set_param_values(
                    mean_w, trainable=True)

                # pdb.set_trace()
                g_var = f_var_grad(*(batch))
                var_w = var_w - self._learning_rate * g_var
                self._target.get_std_network().set_param_values(var_w,
                                                                trainable=True)

                new_loss = f_loss(*(tuple(inputs) + extra_inputs))
                print("mean: batch {:} grad {:}, weight {:}".format(
                    num_batch, LA.norm(g), LA.norm(mean_w)))
                print(
                    "var: batch {:}, loss {:}, diff loss {:}, grad {:}, weight {:}"
                    .format(num_batch, new_loss, new_loss - loss,
                            LA.norm(g_var), LA.norm(var_w)))
                loss = new_loss
                num_batch += 1
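
Example #3 is the plain stochastic-gradient variant: each mini-batch applies w <- w - learning_rate * g separately to the mean-network and std-network parameter vectors, with no variance reduction or curvature correction. Below is a minimal self-contained sketch of that update rule on a toy quadratic loss; all names are hypothetical.

import numpy as np

def sgd(grad_fn, w0, learning_rate=0.01, num_steps=100):
    # Plain SGD, the same update used above for both mean_w and var_w:
    # w <- w - learning_rate * g
    w = np.copy(w0)
    for _ in range(num_steps):
        g = grad_fn(w)
        w = w - learning_rate * g
    return w

# Toy usage: minimize 0.5 * ||w||^2, whose gradient is w itself;
# the norm of w shrinks geometrically toward zero.
w_final = sgd(lambda w: w, w0=np.ones(5), learning_rate=0.1, num_steps=50)
print(np.linalg.norm(w_final))
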
Example #4
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_grad_tilde = self._opt_fun["f_grad_tilde"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(self._target.get_param_values(trainable=True))
        logger.log(
            "Start SVRG CG subsample optimization: "
            "#parameters: %d, #inputs: %d, #subsample_inputs: %d" %
            (len(param), len(inputs[0]),
             self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(inputs,
                                  int(self._subsample_factor * len(inputs[0])),
                                  extra_inputs=extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            # g_mean_tilde = 1/n \sum_b \partial loss(w_tilde, b) / \partial w_tilde
            grad_sum = np.zeros_like(param)
            g_mean_tilde = sliced_fun(f_grad_tilde,
                                      self._num_slices)(inputs, extra_inputs)
            logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde))
            print("-------------mini-batch-------------------")
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                # todo, pick mini-batch with weighted prob.
                if self._use_SGD:
                    g = f_grad(*(batch))
                else:
                    g = f_grad(*(batch)) - \
                            f_grad_tilde(*(batch)) + g_mean_tilde
                grad_sum += g
                subsample_inputs = subsamples.random_batch()
                # pdb.set_trace()  # debugging breakpoint (disabled)
                Hx = self._hvp_approach.build_eval(subsample_inputs)
                self.conjugate_grad(g, Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            grad_sum /= 1.0 * num_batch
            if self._verbose:
                progbar.update(batch[0].shape[0])
            logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde))

            cur_w = np.copy(self._target.get_param_values(trainable=True))
            w_tilde = self._target_tilde.get_param_values(trainable=True)
            self._target_tilde.set_param_values(cur_w, trainable=True)
            logger.record_tabular('wnorm', LA.norm(cur_w))
            logger.record_tabular('w_dist',
                                  LA.norm(cur_w - w_tilde) / LA.norm(cur_w))

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            if abs(LA.norm(cur_w - w_tilde) /
                   LA.norm(cur_w)) < self._tolerance:
                break
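
The distinctive line in Example #4 is the variance-reduced gradient g = f_grad(batch) - f_grad_tilde(batch) + g_mean_tilde: the mini-batch gradient at the current weights, minus the same mini-batch gradient at the snapshot weights w_tilde, plus the full-dataset gradient at the snapshot. This is the standard SVRG estimator; it stays unbiased while its variance shrinks as the current weights approach the snapshot. Below is a minimal numpy sketch on a least-squares problem; all names are hypothetical.

import numpy as np

def svrg_gradient(grad_fn, w, w_tilde, g_mean_tilde, batch):
    # SVRG estimator: unbiased for the full gradient at w, with variance
    # that vanishes as w approaches the snapshot w_tilde.
    return grad_fn(w, batch) - grad_fn(w_tilde, batch) + g_mean_tilde

# Toy usage on the loss 0.5 * ||X w - y||^2.
rng = np.random.RandomState(0)
X, y = rng.randn(100, 5), rng.randn(100)

def grad_fn(w, idx):
    Xb, yb = X[idx], y[idx]
    return Xb.T.dot(Xb.dot(w) - yb) / len(idx)

w_tilde = rng.randn(5)                               # snapshot weights
g_mean_tilde = grad_fn(w_tilde, np.arange(len(y)))   # full gradient at the snapshot
w = np.copy(w_tilde)
for _ in range(50):
    batch = rng.randint(0, len(y), size=10)
    w -= 0.1 * svrg_gradient(grad_fn, w, w_tilde, g_mean_tilde, batch)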