Code Example #1
    def optimize(self, inputs, extra_inputs=None):

        if extra_inputs is None:
            extra_inputs = list()

        dataset = BatchDataset(inputs=inputs, batch_size=self._batch_size, extra_inputs=extra_inputs)
        cg_dataset = BatchDataset(inputs=inputs, batch_size=self._cg_batch_size, extra_inputs=extra_inputs)

        itr = [0]
        start_time = time.time()

        if self._callback:
            def opt_callback():
                loss = self._opt_fun["f_loss"](*(inputs + extra_inputs))
                elapsed = time.time() - start_time
                self._callback(dict(
                    loss=loss,
                    params=self._target.get_param_values(trainable=True),
                    itr=itr[0],
                    elapsed=elapsed,
                ))
                itr[0] += 1
        else:
            opt_callback = None

        self._hf_optimizer.train(
            gradient_dataset=dataset,
            cg_dataset=cg_dataset,
            itr_callback=opt_callback,
            num_updates=self._max_opt_itr,
            preconditioner=True,
            verbose=True
        )
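In this Hessian-free example, `opt_callback` bundles the current loss, parameter values, iteration counter, and elapsed time into a dict and hands it to `self._callback`. A minimal sketch of a callback that could consume that dict is shown below; the logging it performs is purely illustrative, and the way it is registered on the optimizer is an assumption, not something shown in the snippet.

    def log_progress(info):
        # Receives the dict built by opt_callback above:
        # keys are loss, params, itr, elapsed.
        print("itr {itr}: loss={loss:.6f} ({elapsed:.1f}s elapsed)".format(**info))

    # Hypothetical wiring: supplied to the optimizer's constructor so that it
    # ends up in self._callback (the exact constructor argument is assumed).
    # optimizer = SomeHessianFreeOptimizer(callback=log_progress)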
Code Example #2
    def optimize(self, inputs, extra_inputs=None, callback=None):
        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            gc.collect()
            for batch in dataset.iterate(update=True):
                sess.run(self._train_op,
                         dict(list(zip(self._input_vars, batch))))
                if self._verbose:
                    progbar.update(len(batch[0]))
            gc.collect()

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))
            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
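Every example on this page relies on `BatchDataset` for mini-batching: the constructor takes the input arrays, a batch size, and `extra_inputs`; `iterate(update=True)` yields (shuffled) mini-batches with the extra inputs appended; `random_batch()` draws a single random batch. The class below is only a stand-in inferred from that usage, not rllab's actual implementation.

    import numpy as np

    class SimpleBatchDataset(object):
        """Illustrative stand-in for the BatchDataset interface used above."""

        def __init__(self, inputs, batch_size, extra_inputs=None):
            self._inputs = [np.asarray(x) for x in inputs]
            self._batch_size = batch_size
            self._extra_inputs = tuple(extra_inputs) if extra_inputs else tuple()
            self._size = len(self._inputs[0])

        def random_batch(self):
            idx = np.random.randint(0, self._size, size=self._batch_size)
            return tuple(x[idx] for x in self._inputs) + self._extra_inputs

        def iterate(self, update=True):
            # update=True reshuffles the data before each pass.
            order = np.random.permutation(self._size) if update else np.arange(self._size)
            for start in range(0, self._size, self._batch_size):
                idx = order[start:start + self._batch_size]
                yield tuple(x[idx] for x in self._inputs) + self._extra_inputs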
Code Example #3
    def optimize(self, inputs, extra_inputs=None, callback=None):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            for batch in dataset.iterate(update=True):
                if self._init_train_op is not None:
                    sess.run(self._init_train_op, dict(list(zip(self._input_vars, batch))))
                    self._init_train_op = None  # only use it once
                else:
                    sess.run(self._train_op, dict(list(zip(self._input_vars, batch))))

                if self._verbose:
                    progbar.update(len(batch[0]))

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))
            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
Code Example #4
    def optimize_gen(self,
                     inputs,
                     extra_inputs=None,
                     callback=None,
                     yield_itr=None):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs
                               #, randomized=self._randomized
                               )

        itr = 0
        for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
            for batch in dataset.iterate(update=True):
                f_opt(*batch)
                if yield_itr is not None and (itr % (yield_itr + 1)) == 0:
                    yield
                itr += 1

            new_loss = self.loss(inputs, extra_inputs)
            if self._verbose:
                logger.log("Epoch %d, loss %s" % (epoch, new_loss))

            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
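`optimize_gen` is a generator: it yields control back to the caller once every `yield_itr + 1` mini-batch updates, so training can be interleaved with other work. A hypothetical driver loop (the `optimizer`, `inputs`, and `extra_inputs` names are placeholders, not taken from the source) might look like:

    gen = optimizer.optimize_gen(inputs, extra_inputs=extra_inputs, yield_itr=9)
    for _ in gen:
        # Reached once every 10 mini-batch updates; a convenient point to
        # log metrics, snapshot parameters, or test a stopping condition.
        pass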
Code Example #5
File: stein_sgd_optimizer.py  Project: wolfustc/rllab
    def optimize(self, inputs, extra_inputs=None, callback=None):
        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_var_grad = self._opt_fun["f_var_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        mean_w = self._target.get_mean_network().get_param_values(
            trainable=True)
        var_w = self._target.get_std_network().get_param_values(trainable=True)
        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            num_batch = 0
            loss = f_loss(*(tuple(inputs) + extra_inputs))
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                g = f_grad(*(batch))
                # w = w - \eta g
                mean_w = mean_w - self._learning_rate * g
                self._target.get_mean_network().set_param_values(
                    mean_w, trainable=True)

                g_var = f_var_grad(*(batch))
                var_w = var_w - self._learning_rate * g_var
                self._target.get_std_network().set_param_values(var_w,
                                                                trainable=True)

                new_loss = f_loss(*(tuple(inputs) + extra_inputs))
                print("mean: batch {:} grad {:}, weight {:}".format(
                    num_batch, LA.norm(g), LA.norm(mean_w)))
                print(
                    "var: batch {:}, loss {:}, diff loss{:}, grad {:}, weight {:}"
                    .format(num_batch, new_loss, new_loss - loss,
                            LA.norm(g_var), LA.norm(var_w)))
                loss = new_loss
                num_batch += 1
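The inner loop above is plain SGD on two parameter vectors: each iteration applies w <- w - eta * g with the mini-batch gradient, once for the mean network and once for the std network. The following self-contained toy example (unrelated to rllab's networks) just illustrates that update rule converging on a quadratic loss:

    import numpy as np

    # Minimize f(w) = 0.5 * ||w - target||^2, whose gradient is (w - target).
    target = np.array([1.0, -2.0, 3.0])
    w = np.zeros(3)
    eta = 0.1  # learning rate, playing the role of self._learning_rate

    for step in range(200):
        g = w - target      # gradient of the toy loss at w
        w = w - eta * g     # the same update applied to mean_w and var_w above
    print(w)                # close to [1, -2, 3]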
Code Example #6
    def optimize_gen(self, inputs, extra_inputs=None, callback=None, yield_itr=None):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(
            inputs, self._batch_size,
            extra_inputs=extra_inputs
            #, randomized=self._randomized
        )

        itr = 0
        for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
            for batch in dataset.iterate(update=True):
                f_opt(*batch)
                if yield_itr is not None and (itr % (yield_itr+1)) == 0:
                    yield
                itr += 1

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))
            if self._verbose:
                logger.log("Epoch %d, loss %s" % (epoch, new_loss))

            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
Code Example #7
    def optimize(self, inputs, extra_inputs=None, callback=None):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)
            for batch in dataset.iterate(update=True):
                f_opt(*batch)

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
Code Example #8
File: first_order_optimizer.py  Project: cknd/rllab
    def optimize(self, inputs, extra_inputs=None, callback=None):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        f_opt = self._opt_fun["f_opt"]
        f_loss = self._opt_fun["f_loss"]

        if extra_inputs is None:
            extra_inputs = tuple()

        last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % epoch)
            for batch in dataset.iterate(update=True):
                f_opt(*batch)

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))

            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=new_loss,
                    params=self._target.get_param_values(trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)

            if abs(last_loss - new_loss) < self._tolerance:
                break
            last_loss = new_loss
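Note the two call conventions in these first-order examples: `self._callback` receives the whole `callback_args` dict as a single positional argument, while an externally supplied `callback` is expanded with `callback(**callback_args)`. A hypothetical external callback therefore has to accept those keys by name, for example:

    def my_callback(loss, params, itr, elapsed, **kwargs):
        # Invoked once per epoch with the values assembled in callback_args.
        print("epoch %d: loss=%.5f after %.1fs" % (itr, loss, elapsed))

    # Hypothetical call; the optimizer object itself is not shown in the snippets.
    # optimizer.optimize(inputs, callback=my_callback)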
Code Example #9
File: stein_optimizer.py  Project: wolfustc/rllab
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_var_grad = self._opt_fun["f_var_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(
            self._target.get_mean_network().get_param_values(trainable=True))
        logger.log(
            "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
            % (len(param), len(
                inputs[0]), self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(inputs,
                                  int(self._subsample_factor * len(inputs[0])),
                                  extra_inputs=extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                subsample_inputs = subsamples.random_batch()
                g = f_grad(*(batch))
                mean_Hx = self._mean_hvp.build_eval(subsample_inputs)
                self._target_network = self._target.get_mean_network()
                self.conjugate_grad(g, mean_Hx, inputs, extra_inputs)

                # update var_network weights
                var_g = f_var_grad(*(batch))
                var_Hx = self._var_hvp.build_eval(subsample_inputs)
                self._target_network = self._target.get_std_network()
                self.conjugate_grad(var_g, var_Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            if self._verbose:
                progbar.update(batch[0].shape[0])

            if self._verbose:
                if progbar.active:
                    progbar.stop()
Code Example #10
File: robust_trpo.py  Project: tianbingsz/SVRG
    def optimize(self, inputs, extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(self._target.get_param_values(trainable=True))
        logger.log("Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d" % (
            len(param), len(inputs[0]), self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(
            inputs,
            int(self._subsample_factor * len(inputs[0])),
            extra_inputs=extra_inputs)

        dataset = BatchDataset(
            inputs,
            self._batch_size,
            extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            # g_u = 1/n \sum_b \partial loss(w_tilde, b) / \partial w_tilde
            print("-------------mini-batch-------------------")
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                # todo, pick mini-batch with weighted prob.
                g = f_grad(*(batch))
                subsample_inputs = subsamples.random_batch()
                Hx = self._hvp_approach.build_eval(subsample_inputs)
                self.conjugate_grad(g, Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            if self._verbose:
                progbar.update(batch[0].shape[0])

            cur_w = np.copy(self._target.get_param_values(trainable=True))
            logger.record_tabular('wnorm', LA.norm(cur_w))

            if self._verbose:
                if progbar.active:
                    progbar.stop()
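Examples #9, #10, and #12 hand a mini-batch gradient g and an HVP (Hessian-vector product) evaluator Hx to `self.conjugate_grad`, which conceptually solves H d = g for a search direction using only Hessian-vector products evaluated on the subsample. The rllab implementation is not shown on this page; the sketch below is just a generic conjugate-gradient solver illustrating that idea, with a toy SPD matrix standing in for the Hessian.

    import numpy as np

    def conjugate_gradient(Avp, b, iters=10, tol=1e-10):
        # Solves A x = b using only matrix-vector products Avp(v) = A v,
        # which is how an HVP is consumed in the snippets above.
        x = np.zeros_like(b)
        r = b.copy()
        p = b.copy()
        rdotr = r.dot(r)
        for _ in range(iters):
            Ap = Avp(p)
            alpha = rdotr / p.dot(Ap)
            x += alpha * p
            r -= alpha * Ap
            new_rdotr = r.dot(r)
            if new_rdotr < tol:
                break
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr
        return x

    # Toy usage: Avp closes over A the way Hx closes over the subsample.
    A = np.array([[4.0, 1.0], [1.0, 3.0]])
    b = np.array([1.0, 2.0])
    x = conjugate_gradient(lambda v: A.dot(v), b)   # approx. [0.0909, 0.6364]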
Code Example #11
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 callback=None,
                 val_inputs=[None],
                 val_extra_inputs=tuple([None])):

        if len(inputs) == 0:
            # Assumes that we should always sample mini-batches
            raise NotImplementedError

        assert len(inputs) == 1

        dataset_size, _ = inputs[0].shape
        f_loss = self._opt_fun["f_loss"]

        # Plot individual costs from complexity / likelihood terms.
        try:
            use_c_loss = True
            c_loss = self._opt_fun["c_loss"]
        except KeyError:
            use_c_loss = False
        try:
            use_l_loss = True
            l_loss = self._opt_fun["l_loss"]
        except KeyError:
            use_l_loss = False

        if extra_inputs is None:
            extra_inputs = tuple()

        #last_loss = f_loss(*(tuple(inputs) + extra_inputs))

        start_time = time.time()

        train_dataset = BatchDataset(inputs,
                                     self._batch_size,
                                     extra_inputs=extra_inputs)
        if not all([vi is None for vi in val_inputs]):
            val_dataset = BatchDataset(val_inputs,
                                       self._batch_size,
                                       extra_inputs=val_extra_inputs)

        sess = tf.get_default_session()

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))

            train_losses = []
            train_c_losses, train_l_losses = [], []
            # batch is a (matrix X, matrix Y) tuple
            for t, batch in enumerate(train_dataset.iterate(update=True)):
                sess.run(self._train_op,
                         dict(list(zip(self._input_vars, batch))))
                train_losses.append(f_loss(*batch))
                if use_c_loss:
                    train_c_losses.append(c_loss(*batch))
                if use_l_loss:
                    train_l_losses.append(l_loss(*batch))

                if self._verbose:
                    progbar.update(len(batch[0]))

            train_loss = np.mean(train_losses)
            train_c_loss = np.mean(train_c_losses) if use_c_loss else None
            train_l_loss = np.mean(train_l_losses) if use_l_loss else None

            val_losses = []
            val_loss = None
            if not all([vi is None for vi in val_inputs]):
                for t, batch in enumerate(val_dataset.iterate(update=True)):
                    val_losses.append(f_loss(*batch))

                val_loss = np.mean(val_losses)

            for interval, op in self._update_priors_ops:
                if interval != 0 and epoch % interval == 0:
                    sess.run(op)

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            if self._verbose:
                logger.log("Epoch: %d | Loss: %f" % (epoch, train_loss))
            if self._callback or callback:
                elapsed = time.time() - start_time
                callback_args = dict(
                    loss=train_loss,
                    params=self._target.get_param_values(
                        trainable=True) if self._target else None,
                    itr=epoch,
                    elapsed=elapsed,
                )
                if use_c_loss:
                    callback_args['c_loss'] = train_c_loss
                if use_l_loss:
                    callback_args['l_loss'] = train_l_loss

                if val_loss is not None:
                    callback_args.update({'val_loss': val_loss})
                if self._callback:
                    self._callback(callback_args)
                if callback:
                    callback(**callback_args)
Code Example #12
    def optimize(self,
                 inputs,
                 extra_inputs=None,
                 subsample_grouped_inputs=None):

        if len(inputs) == 0:
            raise NotImplementedError

        f_loss = self._opt_fun["f_loss"]
        f_grad = self._opt_fun["f_grad"]
        f_grad_tilde = self._opt_fun["f_grad_tilde"]

        inputs = tuple(inputs)
        if extra_inputs is None:
            extra_inputs = tuple()
        else:
            extra_inputs = tuple(extra_inputs)

        param = np.copy(self._target.get_param_values(trainable=True))
        logger.log(
            "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
            % (len(param), len(
                inputs[0]), self._subsample_factor * len(inputs[0])))

        subsamples = BatchDataset(inputs,
                                  int(self._subsample_factor * len(inputs[0])),
                                  extra_inputs=extra_inputs)

        dataset = BatchDataset(inputs,
                               self._batch_size,
                               extra_inputs=extra_inputs)

        for epoch in range(self._max_epochs):
            if self._verbose:
                logger.log("Epoch %d" % (epoch))
                progbar = pyprind.ProgBar(len(inputs[0]))
            # g_u = 1/n \sum_b \partial loss(w_tilde, b) / \partial w_tilde
            grad_sum = np.zeros_like(param)
            g_mean_tilde = sliced_fun(f_grad_tilde,
                                      self._num_slices)(inputs, extra_inputs)
            logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde))
            print("-------------mini-batch-------------------")
            num_batch = 0
            while num_batch < self._max_batch:
                batch = dataset.random_batch()
                # todo, pick mini-batch with weighted prob.
                if self._use_SGD:
                    g = f_grad(*(batch))
                else:
                    g = f_grad(*(batch)) - \
                            f_grad_tilde(*(batch)) + g_mean_tilde
                grad_sum += g
                subsample_inputs = subsamples.random_batch()
                Hx = self._hvp_approach.build_eval(subsample_inputs)
                self.conjugate_grad(g, Hx, inputs, extra_inputs)
                num_batch += 1
            print("max batch achieved {:}".format(num_batch))
            grad_sum /= 1.0 * num_batch
            if self._verbose:
                progbar.update(batch[0].shape[0])
            logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde))

            cur_w = np.copy(self._target.get_param_values(trainable=True))
            w_tilde = self._target_tilde.get_param_values(trainable=True)
            self._target_tilde.set_param_values(cur_w, trainable=True)
            logger.record_tabular('wnorm', LA.norm(cur_w))
            logger.record_tabular('w_dist',
                                  LA.norm(cur_w - w_tilde) / LA.norm(cur_w))

            if self._verbose:
                if progbar.active:
                    progbar.stop()

            if abs(LA.norm(cur_w - w_tilde) /
                   LA.norm(cur_w)) < self._tolerance:
                break
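The heart of this last example is the variance-reduced (SVRG) gradient `g = f_grad(batch) - f_grad_tilde(batch) + g_mean_tilde`: a mini-batch gradient at the current parameters, corrected by the same batch's gradient at the snapshot parameters w_tilde plus the full-data gradient at w_tilde. The sketch below reproduces just that correction on a toy least-squares problem with NumPy; every name in it is hypothetical and it involves none of rllab's networks or the conjugate-gradient step.

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(200, 5)
    y = X.dot(rng.randn(5))

    def grad(w, idx):
        # Gradient of 0.5 * mean((X w - y)^2) over the rows in idx.
        Xb, yb = X[idx], y[idx]
        return Xb.T.dot(Xb.dot(w) - yb) / len(idx)

    w = np.zeros(5)
    for epoch in range(20):                     # outer loop: one snapshot per epoch
        w_tilde = w.copy()                      # snapshot (self._target_tilde above)
        g_mean_tilde = grad(w_tilde, np.arange(len(y)))  # full gradient at the snapshot
        for _ in range(50):                     # inner mini-batch loop
            idx = rng.randint(0, len(y), size=16)
            # SVRG correction: unbiased, with variance shrinking as w approaches w_tilde
            g = grad(w, idx) - grad(w_tilde, idx) + g_mean_tilde
            w = w - 0.1 * g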