def optimize(self, inputs, extra_inputs=None, callback=None):
    """Minimize the loss over `inputs`, iterating over minibatches."""
    if not inputs:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError
    if self._opt_fun is None:
        raise Exception(
            'Use update_opt() to setup the loss function first.')

    f_loss = self._opt_fun['f_loss']

    if extra_inputs is None:
        extra_inputs = tuple()

    # Loss over the full dataset, used for the convergence check below.
    last_loss = f_loss(*(tuple(inputs) + extra_inputs))

    start_time = time.time()

    dataset = BatchDataset(inputs,
                           self._batch_size,
                           extra_inputs=extra_inputs)

    sess = tf.compat.v1.get_default_session()

    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log('Epoch {}'.format(epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))

        # One gradient step per minibatch.
        for batch in dataset.iterate(update=True):
            sess.run(self._train_op,
                     dict(list(zip(self._input_vars, batch))))
            if self._verbose:
                progbar.update(len(batch[0]))

        if self._verbose:
            if progbar.active:
                progbar.stop()

        new_loss = f_loss(*(tuple(inputs) + extra_inputs))

        if self._verbose:
            logger.log('Epoch: {} | Loss: {}'.format(epoch, new_loss))

        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(
                    trainable=True) if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)

        # Stop early once the per-epoch improvement falls below tolerance.
        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
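# --- Hedged usage sketch (not part of the class above) ----------------------
# A minimal, self-contained illustration of the per-minibatch step the epoch
# loop performs: one `sess.run(train_op, feed_dict)` call, with the feed_dict
# built by zipping input placeholders against the batch arrays. Every name
# below (`x`, `y`, `w`, `loss`, `train_op`, `batch`) is illustrative only and
# not part of the optimizer's API.
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.float32, (None, 3))
y = tf.compat.v1.placeholder(tf.float32, (None, 1))
w = tf.compat.v1.get_variable('w', shape=(3, 1),
                              initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)

input_vars = (x, y)
batch = (np.random.randn(8, 3).astype(np.float32),
         np.random.randn(8, 1).astype(np.float32))

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # One optimization step on a single minibatch, as in the epoch loop.
    sess.run(train_op, dict(zip(input_vars, batch)))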
def get_minibatch(self, *inputs):
    r"""Yields a batch of inputs.

    Notes: P is the size of minibatch (self._minibatch_size)

    Args:
        *inputs (list[torch.Tensor]): A list of inputs. Each input has
            shape :math:`(N \cdot [T], *)`.

    Yields:
        list[torch.Tensor]: A list batch of inputs. Each batch has shape
            :math:`(P, *)`.

    """
    batch_dataset = BatchDataset(inputs, self._minibatch_size)

    for _ in range(self._max_optimization_epochs):
        for dataset in batch_dataset.iterate():
            yield dataset
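# --- Hedged usage sketch (not part of the class above) ----------------------
# A self-contained illustration of the iteration pattern get_minibatch()
# implements: shuffle N samples, then yield slices of size P for each of
# `max_epochs` optimization epochs. This re-implements the pattern inline
# rather than using BatchDataset; the names `_iter_minibatches`, `obs`, and
# `acts` are hypothetical.
import torch


def _iter_minibatches(inputs, minibatch_size, max_epochs):
    n = inputs[0].shape[0]
    for _ in range(max_epochs):
        perm = torch.randperm(n)
        for start in range(0, n, minibatch_size):
            idx = perm[start:start + minibatch_size]
            yield [inp[idx] for inp in inputs]


obs = torch.randn(10, 4)   # (N, *) observations
acts = torch.randn(10, 2)  # (N, *) actions
for obs_b, acts_b in _iter_minibatches([obs, acts],
                                       minibatch_size=4,
                                       max_epochs=2):
    print(obs_b.shape, acts_b.shape)  # each batch has shape (P, *)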
def optimize(self, inputs, extra_inputs=None, callback=None): """Perform optimization. Args: inputs (list[numpy.ndarray]): List of input values. extra_inputs (list[numpy.ndarray]): List of extra input values. callback (callable): Function to call during each epoch. Default is None. Raises: NotImplementedError: If inputs are invalid. Exception: If loss function is None, i.e. not defined. """ if not inputs: # Assumes that we should always sample mini-batches raise NotImplementedError('No inputs are fed to optimizer.') if self._opt_fun is None: raise Exception( 'Use update_opt() to setup the loss function first.') f_loss = self._opt_fun['f_loss'] if extra_inputs is None: extra_inputs = tuple() last_loss = f_loss(*(tuple(inputs) + extra_inputs)) start_time = time.time() dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs) sess = tf.compat.v1.get_default_session() for epoch in range(self._max_optimization_epochs): if self._verbose: logger.log('Epoch {}'.format(epoch)) with click.progressbar(length=len(inputs[0]), label='Optimizing minibatches') as pbar: for batch in dataset.iterate(update=True): sess.run(self._train_op, dict(list(zip(self._input_vars, batch)))) pbar.update(len(batch[0])) new_loss = f_loss(*(tuple(inputs) + extra_inputs)) if self._verbose: logger.log('Epoch: {} | Loss: {}'.format(epoch, new_loss)) if self._callback or callback: elapsed = time.time() - start_time callback_args = dict( loss=new_loss, params=self._target.get_param_values() if self._target else None, itr=epoch, elapsed=elapsed, ) if self._callback: self._callback(callback_args) if callback: callback(**callback_args) if abs(last_loss - new_loss) < self._tolerance: break last_loss = new_loss