def fit(
    self,
    data,
    epochs=100,
    pos_batch_size=100,
    neg_batch_size=None,
    k=1,
    lr=1e-3,
    input_bases=None,
    progbar=False,
    starting_epoch=1,
    time=False,
    callbacks=None,
    optimizer=torch.optim.SGD,
    **kwargs,
):
    """Train the WaveFunction.

    :param data: The training samples
    :type data: np.array
    :param epochs: The number of full training passes through the dataset.
                   Technically, this specifies the index of the *last* training
                   epoch, which is relevant if `starting_epoch` is being set.
    :type epochs: int
    :param pos_batch_size: The size of batches for the positive phase
                           taken from the data.
    :type pos_batch_size: int
    :param neg_batch_size: The size of batches for the negative phase
                           taken from the data. Defaults to `pos_batch_size`.
    :type neg_batch_size: int
    :param k: The number of contrastive divergence steps.
    :type k: int
    :param lr: Learning rate
    :type lr: float
    :param input_bases: The measurement bases for each sample. Must be provided
                        if training a ComplexWaveFunction.
    :type input_bases: np.array
    :param progbar: Whether or not to display a progress bar. If "notebook"
                    is passed, will use a Jupyter notebook compatible progress bar.
    :type progbar: bool or str
    :param starting_epoch: The epoch to start from. Useful if continuing training
                           from a previous state.
    :type starting_epoch: int
    :param time: Whether to time the training procedure (adds a Timer callback).
    :type time: bool
    :param callbacks: Callbacks to run while training.
    :type callbacks: list[qucumber.callbacks.CallbackBase]
    :param optimizer: The constructor of a torch optimizer.
    :type optimizer: torch.optim.Optimizer
    :param kwargs: Keyword arguments to pass to the optimizer
    """
    if self.stop_training:  # terminate immediately if stop_training is true
        return

    disable_progbar = progbar is False
    progress_bar = tqdm_notebook if progbar == "notebook" else tqdm

    callbacks = CallbackList(callbacks if callbacks else [])
    if time:
        callbacks.append(Timer())

    neg_batch_size = neg_batch_size if neg_batch_size else pos_batch_size

    if isinstance(data, torch.Tensor):
        train_samples = data.clone().detach().to(device=self.device, dtype=torch.double)
    else:
        train_samples = torch.tensor(data, device=self.device, dtype=torch.double)

    if len(self.networks) > 1:
        all_params = [getattr(self, net).parameters() for net in self.networks]
        all_params = list(chain(*all_params))
        optimizer = optimizer(all_params, lr=lr, **kwargs)
    else:
        optimizer = optimizer(self.rbm_am.parameters(), lr=lr, **kwargs)

    if input_bases is not None:
        z_samples = extract_refbasis_samples(train_samples, input_bases).to(
            device=self.device
        )
    else:
        z_samples = None

    callbacks.on_train_start(self)

    num_batches = ceil(train_samples.shape[0] / pos_batch_size)
    for ep in progress_bar(
        range(starting_epoch, epochs + 1), desc="Epochs ", disable=disable_progbar
    ):
        data_iterator = self._shuffle_data(
            pos_batch_size,
            neg_batch_size,
            num_batches,
            train_samples,
            input_bases,
            z_samples,
        )
        callbacks.on_epoch_start(self, ep)

        for b, batch in enumerate(data_iterator):
            callbacks.on_batch_start(self, ep, b)

            all_grads = self.compute_batch_gradients(k, *batch)
            optimizer.zero_grad()  # clear any cached gradients

            # assign gradients to corresponding parameters
            for i, net in enumerate(self.networks):
                rbm = getattr(self, net)
                vector_to_grads(all_grads[i], rbm.parameters())

            optimizer.step()  # tell the optimizer to apply the gradients

            callbacks.on_batch_end(self, ep, b)
            if self.stop_training:  # check for stop_training signal
                break

        callbacks.on_epoch_end(self, ep)
        if self.stop_training:  # check for stop_training signal
            break

    callbacks.on_train_end(self)
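# Usage sketch (not part of this module): an illustrative call to the `fit` method
# above, assuming it is exposed by qucumber's PositiveWaveFunction. The random
# training data and all hyperparameter values below are placeholders, not a
# meaningful reconstruction.
#
#     import numpy as np
#     from qucumber.nn_states import PositiveWaveFunction
#
#     train_data = np.random.randint(0, 2, size=(1000, 10))  # fake 0/1 measurement outcomes
#     nn_state = PositiveWaveFunction(num_visible=10, num_hidden=10)
#     nn_state.fit(
#         train_data,
#         epochs=50,
#         pos_batch_size=100,
#         neg_batch_size=200,
#         k=10,
#         lr=0.01,
#         progbar=True,
#     )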
def fit(
    self,
    data,
    input_bases,
    target=None,
    epochs=100,
    pos_batch_size=100,
    neg_batch_size=None,
    k=1,
    lr=1,
    progbar=False,
    starting_epoch=1,
    callbacks=None,
    time=False,
    optimizer=torch.optim.Adadelta,
    scheduler=torch.optim.lr_scheduler.MultiStepLR,
    lr_drop_epoch=50,
    lr_drop_factor=1.0,
    bases=None,
    train_to_fid=False,
    track_fid=False,
    **kwargs,
):
    r"""Train the density matrix.

    :param data: The training samples
    :type data: torch.Tensor
    :param input_bases: The measurement bases for each sample
    :type input_bases: numpy.ndarray
    :param target: The density matrix you are trying to train towards
    :type target: torch.Tensor
    :param epochs: The number of epochs to train for
    :type epochs: int
    :param pos_batch_size: The size of batches for the positive phase
    :type pos_batch_size: int
    :param neg_batch_size: The size of batches for the negative phase.
                           Defaults to `pos_batch_size`.
    :type neg_batch_size: int
    :param k: Number of contrastive divergence steps
    :type k: int
    :param lr: Learning rate. Note that its exact meaning depends on the
               chosen optimizer.
    :type lr: float
    :param progbar: Whether or not to display a progress bar. Pass "notebook"
                    for a Jupyter notebook-friendly version
    :type progbar: bool or str
    :param starting_epoch: The epoch to start from
    :type starting_epoch: int
    :param callbacks: Callbacks to run while training
    :type callbacks: list[qucumber.callbacks.CallbackBase]
    :param time: Whether to time the training procedure (adds a Timer callback)
    :type time: bool
    :param optimizer: The constructor of a torch optimizer
    :type optimizer: torch.optim.Optimizer
    :param scheduler: The constructor of a torch scheduler
    :param lr_drop_epoch: The epoch, or list of epochs, at which the base
                          learning rate is dropped
    :type lr_drop_epoch: int or list[int]
    :param lr_drop_factor: The factor by which the scheduler will decrease the
                           learning rate after the prescribed number of steps
    :type lr_drop_factor: float
    :param bases: All bases in which a measurement is made. Used to check gradients
    :type bases: numpy.ndarray
    :param train_to_fid: Instructs the RBM to end training prematurely if the
                         specified fidelity is reached. If it is never reached,
                         training will continue until the specified epoch
    :type train_to_fid: float or bool
    :param track_fid: A file to which to write the fidelity at every epoch.
                      Useful for keeping track of a training run in the background
    :type track_fid: str or bool
    :param kwargs: Keyword arguments to pass to the optimizer
    """
    disable_progbar = progbar is False
    progress_bar = tqdm_notebook if progbar == "notebook" else tqdm

    lr_drop_epoch = (
        [lr_drop_epoch] if isinstance(lr_drop_epoch, int) else lr_drop_epoch
    )

    callbacks = CallbackList(callbacks if callbacks else [])
    if time:
        callbacks.append(Timer())

    train_samples = data.clone().detach().double().to(device=self.device)

    neg_batch_size = neg_batch_size if neg_batch_size else pos_batch_size

    all_params = [getattr(self, net).parameters() for net in self.networks]
    all_params = list(chain(*all_params))

    optimizer = optimizer(all_params, lr=lr, **kwargs)
    scheduler = scheduler(optimizer, lr_drop_epoch, gamma=lr_drop_factor)

    z_samples = extract_refbasis_samples(train_samples, input_bases)

    num_batches = ceil(train_samples.shape[0] / pos_batch_size)

    callbacks.on_train_start(self)

    for ep in progress_bar(
        range(starting_epoch, epochs + 1), desc="Epochs ", disable=disable_progbar
    ):
        data_iterator = self._shuffle_data(
            pos_batch_size,
            neg_batch_size,
            num_batches,
            train_samples,
            input_bases,
            z_samples,
        )
        callbacks.on_epoch_start(self, ep)

        for b, batch in enumerate(data_iterator):
            callbacks.on_batch_start(self, ep, b)

            all_grads = self.compute_batch_gradients(k, *batch)
            optimizer.zero_grad()  # clear any cached gradients

            # assign gradients to corresponding parameters
            for i, net in enumerate(self.networks):
                rbm = getattr(self, net)
                vector_to_grads(all_grads[i], rbm.parameters())

            optimizer.step()  # apply the gradients

            callbacks.on_batch_end(self, ep, b)

        callbacks.on_epoch_end(self, ep)
        scheduler.step()

        if train_to_fid or track_fid:
            v_space = self.generate_hilbert_space(self.num_visible)
            fidel = ts.density_matrix_fidelity(self, target, v_space)

            if track_fid:
                with open(track_fid, "a") as f:
                    f.write(f"Epoch: {ep}\tFidelity: {fidel}\n")

            if train_to_fid and fidel >= train_to_fid:
                print(
                    "\n\nTarget fidelity of", train_to_fid, "reached or exceeded!"
                )
                break

    callbacks.on_train_end(self)
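# Usage sketch (illustrative only, not part of this module): a possible call to the
# density-matrix `fit` above. The class name `DensityMatrix`, its import path, the
# placeholder data, and all hyperparameter values are assumptions; the released
# qucumber API may differ from this development-style signature. Note that this
# version expects `data` to already be a torch.Tensor.
#
#     import numpy as np
#     import torch
#     from qucumber.nn_states import DensityMatrix  # assumed import path
#
#     num_sites = 5
#     train_samples = torch.randint(0, 2, (1000, num_sites), dtype=torch.double)  # fake 0/1 outcomes
#     train_bases = np.array([["Z"] * num_sites] * 1000)  # basis label per site, per sample
#
#     nn_state = DensityMatrix(num_visible=num_sites)
#     nn_state.fit(
#         train_samples,
#         input_bases=train_bases,
#         epochs=100,
#         pos_batch_size=100,
#         k=10,
#         lr_drop_epoch=[50, 75],  # drop the base learning rate at these epochs
#         lr_drop_factor=0.1,      # multiply the learning rate by 0.1 at each drop
#         progbar=True,
#     )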