# Module-level imports these optimizers rely on. The exact import paths are an
# assumption based on rllab conventions (BatchDataset, sliced_fun, logger, and
# pyprind are all used the way rllab uses them); adjust to this repo's layout.
import numpy as np
from numpy import linalg as LA

import pyprind

from rllab.misc import logger
from rllab.misc.ext import sliced_fun
from rllab.optimizers.minibatch_dataset import BatchDataset


def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    if len(inputs) == 0:
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]
    f_var_grad = self._opt_fun["f_var_grad"]

    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)

    param = np.copy(
        self._target.get_mean_network().get_param_values(trainable=True))
    logger.log(
        "Start SVRG CG subsample optimization: "
        "#parameters: %d, #inputs: %d, #subsample_inputs: %d"
        % (len(param), len(inputs[0]),
           self._subsample_factor * len(inputs[0])))

    subsamples = BatchDataset(
        inputs,
        int(self._subsample_factor * len(inputs[0])),
        extra_inputs=extra_inputs)
    dataset = BatchDataset(
        inputs, self._batch_size, extra_inputs=extra_inputs)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)
            progbar = pyprind.ProgBar(len(inputs[0]))

        num_batch = 0
        while num_batch < self._max_batch:
            batch = dataset.random_batch()
            subsample_inputs = subsamples.random_batch()

            # Update the mean-network weights: CG step against the
            # Hessian-vector product built on the subsampled inputs.
            g = f_grad(*batch)
            mean_Hx = self._mean_hvp.build_eval(subsample_inputs)
            self._target_network = self._target.get_mean_network()
            self.conjugate_grad(g, mean_Hx, inputs, extra_inputs)

            # Update the std-network weights the same way.
            var_g = f_var_grad(*batch)
            var_Hx = self._var_hvp.build_eval(subsample_inputs)
            self._target_network = self._target.get_std_network()
            self.conjugate_grad(var_g, var_Hx, inputs, extra_inputs)

            num_batch += 1

        print("max batch achieved {:}".format(num_batch))
        if self._verbose:
            progbar.update(batch[0].shape[0])
            if progbar.active:
                progbar.stop()
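# Hedged sketch, not this repo's code: `self.conjugate_grad(g, Hx, ...)` above
# is assumed to solve H x = g by the conjugate-gradient method, using only the
# Hessian-vector product callable returned by `build_eval`, before stepping the
# selected target network along x. A minimal standalone solver of that shape
# (the names `hvp`, `g`, and `cg_iters` are illustrative, not the repo's API):
def conjugate_gradient_sketch(hvp, g, cg_iters=10, residual_tol=1e-10):
    x = np.zeros_like(g)
    r = g.copy()                  # residual r = g - H x, with x = 0 initially
    p = r.copy()                  # current search direction
    r_dot_r = r.dot(r)
    for _ in range(cg_iters):
        Hp = hvp(p)
        alpha = r_dot_r / p.dot(Hp)
        x += alpha * p            # step along the conjugate direction
        r -= alpha * Hp
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < residual_tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p
        r_dot_r = new_r_dot_r
    return x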
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    if len(inputs) == 0:
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]

    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)

    param = np.copy(self._target.get_param_values(trainable=True))
    logger.log(
        "Start SVRG CG subsample optimization: "
        "#parameters: %d, #inputs: %d, #subsample_inputs: %d"
        % (len(param), len(inputs[0]),
           self._subsample_factor * len(inputs[0])))

    subsamples = BatchDataset(
        inputs,
        int(self._subsample_factor * len(inputs[0])),
        extra_inputs=extra_inputs)
    dataset = BatchDataset(
        inputs, self._batch_size, extra_inputs=extra_inputs)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)
            progbar = pyprind.ProgBar(len(inputs[0]))

        # g_u = (1/n) * sum_b  d loss(w_tilde, b) / d w_tilde
        print("-------------mini-batch-------------------")
        num_batch = 0
        while num_batch < self._max_batch:
            # TODO: pick the mini-batch with weighted probability.
            batch = dataset.random_batch()
            g = f_grad(*batch)
            subsample_inputs = subsamples.random_batch()
            Hx = self._hvp_approach.build_eval(subsample_inputs)
            self.conjugate_grad(g, Hx, inputs, extra_inputs)
            num_batch += 1

        print("max batch achieved {:}".format(num_batch))
        if self._verbose:
            progbar.update(batch[0].shape[0])

        cur_w = np.copy(self._target.get_param_values(trainable=True))
        logger.record_tabular('wnorm', LA.norm(cur_w))

        if self._verbose:
            if progbar.active:
                progbar.stop()
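# Hedged sketch, an assumption rather than this repo's code:
# `self._hvp_approach.build_eval(subsample_inputs)` above is assumed to return
# a callable v -> H v for the Hessian of the loss restricted to the subsampled
# inputs, in the spirit of rllab's finite-difference HVP. A minimal
# finite-difference version, given a hypothetical gradient function `grad_f`
# evaluated at parameters `w`:
def build_hvp_sketch(grad_f, w, eps=1e-5):
    def hvp(v):
        # Central difference of the gradient along v approximates H v.
        return (grad_f(w + eps * v) - grad_f(w - eps * v)) / (2.0 * eps)
    return hvp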
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]
    f_var_grad = self._opt_fun["f_var_grad"]

    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)

    dataset = BatchDataset(
        inputs, self._batch_size, extra_inputs=extra_inputs)

    mean_w = self._target.get_mean_network().get_param_values(trainable=True)
    var_w = self._target.get_std_network().get_param_values(trainable=True)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)
            progbar = pyprind.ProgBar(len(inputs[0]))

        num_batch = 0
        loss = f_loss(*(inputs + extra_inputs))
        while num_batch < self._max_batch:
            batch = dataset.random_batch()

            # SGD step on the mean network: w <- w - eta * g
            g = f_grad(*batch)
            mean_w = mean_w - self._learning_rate * g
            self._target.get_mean_network().set_param_values(
                mean_w, trainable=True)

            # Same step on the std network.
            g_var = f_var_grad(*batch)
            var_w = var_w - self._learning_rate * g_var
            self._target.get_std_network().set_param_values(
                var_w, trainable=True)

            new_loss = f_loss(*(inputs + extra_inputs))
            print("mean: batch {:}, grad {:}, weight {:}".format(
                num_batch, LA.norm(g), LA.norm(mean_w)))
            print("var: batch {:}, loss {:}, diff loss {:}, grad {:}, "
                  "weight {:}".format(num_batch, new_loss, new_loss - loss,
                                      LA.norm(g_var), LA.norm(var_w)))
            loss = new_loss
            num_batch += 1
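# Hedged sketch: the update above is plain SGD, w <- w - eta * g, applied
# separately to the mean and std networks. On a toy quadratic loss
# f(w) = 0.5 * ||w||^2 (so g = w), the same loop contracts the weights by a
# factor (1 - eta) per step; `sgd_toy_example` is illustrative only.
def sgd_toy_example(eta=0.1, steps=5):
    w = np.array([1.0, -2.0])
    for _ in range(steps):
        g = w                      # gradient of 0.5 * ||w||^2
        w = w - eta * g            # the same update rule as optimize() above
    return w                       # equals (1 - eta)**steps * w0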
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    if len(inputs) == 0:
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]
    f_grad_tilde = self._opt_fun["f_grad_tilde"]

    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)

    param = np.copy(self._target.get_param_values(trainable=True))
    logger.log(
        "Start SVRG CG subsample optimization: "
        "#parameters: %d, #inputs: %d, #subsample_inputs: %d"
        % (len(param), len(inputs[0]),
           self._subsample_factor * len(inputs[0])))

    subsamples = BatchDataset(
        inputs,
        int(self._subsample_factor * len(inputs[0])),
        extra_inputs=extra_inputs)
    dataset = BatchDataset(
        inputs, self._batch_size, extra_inputs=extra_inputs)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)
            progbar = pyprind.ProgBar(len(inputs[0]))

        # Snapshot gradient at w_tilde:
        # g_mean_tilde = (1/n) * sum_b  d loss(w_tilde, b) / d w_tilde
        grad_sum = np.zeros_like(param)
        g_mean_tilde = sliced_fun(f_grad_tilde,
                                  self._num_slices)(inputs, extra_inputs)
        logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde))

        print("-------------mini-batch-------------------")
        num_batch = 0
        while num_batch < self._max_batch:
            # TODO: pick the mini-batch with weighted probability.
            batch = dataset.random_batch()
            if self._use_SGD:
                g = f_grad(*batch)
            else:
                # SVRG variance-reduced gradient:
                # g = grad(w) - grad(w_tilde) + g_mean_tilde
                g = f_grad(*batch) - f_grad_tilde(*batch) + g_mean_tilde
            grad_sum += g

            subsample_inputs = subsamples.random_batch()
            Hx = self._hvp_approach.build_eval(subsample_inputs)
            self.conjugate_grad(g, Hx, inputs, extra_inputs)
            num_batch += 1

        print("max batch achieved {:}".format(num_batch))
        grad_sum /= num_batch
        if self._verbose:
            progbar.update(batch[0].shape[0])

        logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde))

        # Refresh the snapshot parameters w_tilde with the current weights.
        cur_w = np.copy(self._target.get_param_values(trainable=True))
        w_tilde = self._target_tilde.get_param_values(trainable=True)
        self._target_tilde.set_param_values(cur_w, trainable=True)
        logger.record_tabular('wnorm', LA.norm(cur_w))
        logger.record_tabular('w_dist',
                              LA.norm(cur_w - w_tilde) / LA.norm(cur_w))

        if self._verbose:
            if progbar.active:
                progbar.stop()

        # Stop once the relative parameter change falls below tolerance.
        if LA.norm(cur_w - w_tilde) / LA.norm(cur_w) < self._tolerance:
            break
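# Hedged sketch: the variance-reduced estimator above is the standard SVRG
# gradient, g = grad_i(w) - grad_i(w_tilde) + g_mean_tilde, unbiased because
# E_i[grad_i(w_tilde)] = g_mean_tilde. A self-contained least-squares check
# (`A`, `b`, `w`, and `w_tilde` are illustrative, not the repo's data):
def svrg_gradient_sketch(A, b, w, w_tilde):
    n = A.shape[0]
    # Full snapshot gradient of f(w) = (1/2n) * ||A w - b||^2.
    g_mean_tilde = A.T.dot(A.dot(w_tilde) - b) / n
    # Per-sample gradient of f_i(v) = 0.5 * (a_i . v - b_i)^2.
    i = np.random.randint(n)
    grad_i = lambda v: A[i] * (A[i].dot(v) - b[i])
    return grad_i(w) - grad_i(w_tilde) + g_mean_tilde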