def optimize(self, inputs, extra_inputs=None):
    if extra_inputs is None:
        extra_inputs = list()

    dataset = BatchDataset(inputs=inputs, batch_size=self._batch_size,
                           extra_inputs=extra_inputs)
    cg_dataset = BatchDataset(inputs=inputs, batch_size=self._cg_batch_size,
                              extra_inputs=extra_inputs)

    itr = [0]
    start_time = time.time()

    if self._callback:
        def opt_callback():
            loss = self._opt_fun["f_loss"](*(inputs + extra_inputs))
            elapsed = time.time() - start_time
            self._callback(dict(
                loss=loss,
                params=self._target.get_param_values(trainable=True),
                itr=itr[0],
                elapsed=elapsed,
            ))
            itr[0] += 1
    else:
        opt_callback = None

    self._hf_optimizer.train(
        gradient_dataset=dataset,
        cg_dataset=cg_dataset,
        itr_callback=opt_callback,
        num_updates=self._max_opt_itr,
        preconditioner=True,
        verbose=True,
    )
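# A minimal, self-contained sketch (not part of the optimizer above) of the
# callback pattern used in optimize(): a one-element list serves as a mutable
# iteration counter so the nested callback can increment it without rebinding
# an outer name. make_callback and report are hypothetical names.
import time


def make_callback(report):
    itr = [0]
    start_time = time.time()

    def callback():
        # Report the current iteration index and elapsed wall-clock time.
        report(dict(itr=itr[0], elapsed=time.time() - start_time))
        itr[0] += 1

    return callback


cb = make_callback(print)
cb()  # itr 0
cb()  # itr 1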
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

    sess = tf.get_default_session()

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))

        gc.collect()
        for batch in dataset.iterate(update=True):
            sess.run(self._train_op, dict(list(zip(self._input_vars, batch))))
            if self._verbose:
                progbar.update(len(batch[0]))
        gc.collect()

        if self._verbose:
            if progbar.active:
                progbar.stop()

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._verbose:
            logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

    sess = tf.get_default_session()

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))

        for batch in dataset.iterate(update=True):
            if self._init_train_op is not None:
                sess.run(self._init_train_op,
                         dict(list(zip(self._input_vars, batch))))
                self._init_train_op = None  # only use it once
            else:
                sess.run(self._train_op,
                         dict(list(zip(self._input_vars, batch))))
            if self._verbose:
                progbar.update(len(batch[0]))

        if self._verbose:
            if progbar.active:
                progbar.stop()

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._verbose:
            logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
def optimize_gen(self, inputs, extra_inputs=None, callback=None, yield_itr=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_opt = self._opt_fun["f_opt"]
    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(
        inputs,
        self._batch_size,
        extra_inputs=extra_inputs
        # , randomized=self._randomized
    )

    itr = 0
    for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
        for batch in dataset.iterate(update=True):
            f_opt(*batch)
            if yield_itr is not None and (itr % (yield_itr + 1)) == 0:
                yield
            itr += 1

        new_loss = self.loss(inputs, extra_inputs)

        if self._verbose:
            logger.log("Epoch %d, loss %s" % (epoch, new_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        raise NotImplementedError

    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]
    f_var_grad = self._opt_fun["f_var_grad"]

    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)

    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

    mean_w = self._target.get_mean_network().get_param_values(trainable=True)
    var_w = self._target.get_std_network().get_param_values(trainable=True)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))

        num_batch = 0
        loss = f_loss(*(tuple(inputs) + extra_inputs))
        while num_batch < self._max_batch:
            batch = dataset.random_batch()

            # Mean network update: w = w - \eta * g
            g = f_grad(*(batch))
            mean_w = mean_w - self._learning_rate * g
            self._target.get_mean_network().set_param_values(mean_w, trainable=True)

            # Std network update with its own gradient
            g_var = f_var_grad(*(batch))
            var_w = var_w - self._learning_rate * g_var
            self._target.get_std_network().set_param_values(var_w, trainable=True)

            new_loss = f_loss(*(tuple(inputs) + extra_inputs))
            print("mean: batch {:} grad {:}, weight {:}".format(
                num_batch, LA.norm(g), LA.norm(mean_w)))
            print("var: batch {:}, loss {:}, diff loss {:}, grad {:}, weight {:}".format(
                num_batch, new_loss, new_loss - loss, LA.norm(g_var), LA.norm(var_w)))
            loss = new_loss
            num_batch += 1
def optimize_gen(self, inputs, extra_inputs=None, callback=None, yield_itr=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_opt = self._opt_fun["f_opt"]
    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(
        inputs,
        self._batch_size,
        extra_inputs=extra_inputs
        # , randomized=self._randomized
    )

    itr = 0
    for epoch in pyprind.prog_bar(list(range(self._max_epochs))):
        for batch in dataset.iterate(update=True):
            f_opt(*batch)
            if yield_itr is not None and (itr % (yield_itr + 1)) == 0:
                yield
            itr += 1

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._verbose:
            logger.log("Epoch %d, loss %s" % (epoch, new_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
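# A minimal sketch of the generator pattern used by optimize_gen() above:
# control is handed back to the caller every (yield_itr + 1) mini-batch
# updates, so logging or data collection can be interleaved with optimization.
# optimize_gen_sketch and the caller loop are illustrative, not part of the
# optimizer's API.
def optimize_gen_sketch(num_updates, yield_itr=None):
    itr = 0
    for _ in range(num_updates):
        # ... one mini-batch update would run here ...
        if yield_itr is not None and (itr % (yield_itr + 1)) == 0:
            yield itr
        itr += 1


for itr in optimize_gen_sketch(num_updates=10, yield_itr=2):
    pass  # e.g. log progress or refresh data between groups of updates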
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_opt = self._opt_fun["f_opt"]
    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)

        for batch in dataset.iterate(update=True):
            f_opt(*batch)

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
def optimize(self, inputs, extra_inputs=None, callback=None):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError

    f_opt = self._opt_fun["f_opt"]
    f_loss = self._opt_fun["f_loss"]

    if extra_inputs is None:
        extra_inputs = tuple()

    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % epoch)

        for batch in dataset.iterate(update=True):
            f_opt(*batch)

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
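# A minimal sketch of the mini-batch convention the optimizers above assume
# from BatchDataset: `inputs` is a tuple of arrays sharing the same leading
# (sample) dimension, and each batch is a tuple of aligned slices.
# iterate_batches is a hypothetical stand-in, not the actual BatchDataset API.
import numpy as np


def iterate_batches(inputs, batch_size, shuffle=True, seed=None):
    rng = np.random.default_rng(seed)
    n = len(inputs[0])
    order = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield tuple(arr[idx] for arr in inputs)


observations = np.zeros((10, 3))
returns = np.zeros(10)
for batch in iterate_batches((observations, returns), batch_size=4):
    obs_batch, ret_batch = batch  # aligned slices of both arrays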
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None): if len(inputs) == 0: raise NotImplementedError f_loss = self._opt_fun["f_loss"] f_grad = self._opt_fun["f_grad"] f_var_grad = self._opt_fun["f_var_grad"] inputs = tuple(inputs) if extra_inputs is None: extra_inputs = tuple() else: extra_inputs = tuple(extra_inputs) param = np.copy( self._target.get_mean_network().get_param_values(trainable=True)) logger.log( "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d" % (len(param), len( inputs[0]), self._subsample_factor * len(inputs[0]))) subsamples = BatchDataset(inputs, int(self._subsample_factor * len(inputs[0])), extra_inputs=extra_inputs) dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs) for epoch in range(self._max_epochs): if self._verbose: logger.log("Epoch %d" % (epoch)) progbar = pyprind.ProgBar(len(inputs[0])) num_batch = 0 while num_batch < self._max_batch: batch = dataset.random_batch() subsample_inputs = subsamples.random_batch() g = f_grad(*(batch)) mean_Hx = self._mean_hvp.build_eval(subsample_inputs) self._target_network = self._target.get_mean_network() self.conjugate_grad(g, mean_Hx, inputs, extra_inputs) # update var_network weights var_g = f_var_grad(*(batch)) var_Hx = self._var_hvp.build_eval(subsample_inputs) self._target_network = self._target.get_std_network() self.conjugate_grad(var_g, var_Hx, inputs, extra_inputs) num_batch += 1 print("max batch achieved {:}".format(num_batch)) if self._verbose: progbar.update(batch[0].shape[0]) if self._verbose: if progbar.active: progbar.stop()
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None): if len(inputs) == 0: raise NotImplementedError f_loss = self._opt_fun["f_loss"] f_grad = self._opt_fun["f_grad"] inputs = tuple(inputs) if extra_inputs is None: extra_inputs = tuple() else: extra_inputs = tuple(extra_inputs) param = np.copy(self._target.get_param_values(trainable=True)) logger.log("Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d" % ( len(param), len(inputs[0]), self._subsample_factor * len(inputs[0]))) subsamples = BatchDataset( inputs, int(self._subsample_factor * len(inputs[0])), extra_inputs=extra_inputs) dataset = BatchDataset( inputs, self._batch_size, extra_inputs=extra_inputs) for epoch in range(self._max_epochs): if self._verbose: logger.log("Epoch %d" % (epoch)) progbar = pyprind.ProgBar(len(inputs[0])) # g_u = 1/n \sum_{b} \partial{loss(w_tidle, b)} {w_tidle} print("-------------mini-batch-------------------") num_batch = 0 while num_batch < self._max_batch: batch = dataset.random_batch() # todo, pick mini-batch with weighted prob. g = f_grad(*(batch)) subsample_inputs = subsamples.random_batch() Hx = self._hvp_approach.build_eval(subsample_inputs) self.conjugate_grad(g, Hx, inputs, extra_inputs) num_batch += 1 print("max batch achieved {:}".format(num_batch)) if self._verbose: progbar.update(batch[0].shape[0]) cur_w = np.copy(self._target.get_param_values(trainable=True)) logger.record_tabular('wnorm', LA.norm(cur_w)) if self._verbose: if progbar.active: progbar.stop()
def optimize(self, inputs, extra_inputs=None, callback=None,
             val_inputs=[None], val_extra_inputs=tuple([None])):
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError
    assert len(inputs) == 1
    dataset_size, _ = inputs[0].shape

    f_loss = self._opt_fun["f_loss"]

    # Plot individual costs from complexity / likelihood terms.
    try:
        use_c_loss = True
        c_loss = self._opt_fun["c_loss"]
    except KeyError:
        use_c_loss = False
    try:
        use_l_loss = True
        l_loss = self._opt_fun["l_loss"]
    except KeyError:
        use_l_loss = False

    if extra_inputs is None:
        extra_inputs = tuple()

    start_time = time.time()

    train_dataset = BatchDataset(inputs, self._batch_size,
                                 extra_inputs=extra_inputs)
    use_val = not all([vi is None for vi in val_inputs])
    if use_val:
        val_dataset = BatchDataset(val_inputs, self._batch_size,
                                   extra_inputs=val_extra_inputs)

    sess = tf.get_default_session()

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))

        train_losses = []
        train_c_losses, train_l_losses = [], []
        # batch is a (matrix X, matrix Y) tuple
        for t, batch in enumerate(train_dataset.iterate(update=True)):
            sess.run(self._train_op, dict(list(zip(self._input_vars, batch))))
            train_losses.append(f_loss(*batch))
            if use_c_loss:
                train_c_losses.append(c_loss(*batch))
            if use_l_loss:
                train_l_losses.append(l_loss(*batch))
            if self._verbose:
                progbar.update(len(batch[0]))

        train_loss = np.mean(train_losses)
        if use_c_loss:
            train_c_loss = np.mean(train_c_losses)
        if use_l_loss:
            train_l_loss = np.mean(train_l_losses)

        val_loss = None
        if use_val:
            val_losses = []
            for t, batch in enumerate(val_dataset.iterate(update=True)):
                val_losses.append(f_loss(*batch))
            val_loss = np.mean(val_losses)

        for interval, op in self._update_priors_ops:
            if interval != 0 and epoch % interval == 0:
                sess.run(op)

        if self._verbose:
            if progbar.active:
                progbar.stop()

        if self._verbose:
            logger.log("Epoch: %d | Loss: %f" % (epoch, train_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=train_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if use_c_loss:
                callback_args['c_loss'] = train_c_loss
            if use_l_loss:
                callback_args['l_loss'] = train_l_loss
            if val_loss is not None:
                callback_args.update({'val_loss': val_loss})
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None): if len(inputs) == 0: raise NotImplementedError f_loss = self._opt_fun["f_loss"] f_grad = self._opt_fun["f_grad"] f_grad_tilde = self._opt_fun["f_grad_tilde"] inputs = tuple(inputs) if extra_inputs is None: extra_inputs = tuple() else: extra_inputs = tuple(extra_inputs) param = np.copy(self._target.get_param_values(trainable=True)) logger.log( "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d" % (len(param), len( inputs[0]), self._subsample_factor * len(inputs[0]))) subsamples = BatchDataset(inputs, int(self._subsample_factor * len(inputs[0])), extra_inputs=extra_inputs) dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs) for epoch in range(self._max_epochs): if self._verbose: logger.log("Epoch %d" % (epoch)) progbar = pyprind.ProgBar(len(inputs[0])) # g_u = 1/n \sum_{b} \partial{loss(w_tidle, b)} {w_tidle} grad_sum = np.zeros_like(param) g_mean_tilde = sliced_fun(f_grad_tilde, self._num_slices)(inputs, extra_inputs) logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde)) print("-------------mini-batch-------------------") num_batch = 0 while num_batch < self._max_batch: batch = dataset.random_batch() # todo, pick mini-batch with weighted prob. if self._use_SGD: g = f_grad(*(batch)) else: g = f_grad(*(batch)) - \ f_grad_tilde(*(batch)) + g_mean_tilde grad_sum += g subsample_inputs = subsamples.random_batch() pdb.set_trace() Hx = self._hvp_approach.build_eval(subsample_inputs) self.conjugate_grad(g, Hx, inputs, extra_inputs) num_batch += 1 print("max batch achieved {:}".format(num_batch)) grad_sum /= 1.0 * num_batch if self._verbose: progbar.update(batch[0].shape[0]) logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde)) cur_w = np.copy(self._target.get_param_values(trainable=True)) w_tilde = self._target_tilde.get_param_values(trainable=True) self._target_tilde.set_param_values(cur_w, trainable=True) logger.record_tabular('wnorm', LA.norm(cur_w)) logger.record_tabular('w_dist', LA.norm(cur_w - w_tilde) / LA.norm(cur_w)) if self._verbose: if progbar.active: progbar.stop() if abs(LA.norm(cur_w - w_tilde) / LA.norm(cur_w)) < self._tolerance: break