Example #1
    def fit(self,
            data,
            epochs=100,
            pos_batch_size=100,
            neg_batch_size=None,
            k=1,
            lr=1e-3,
            input_bases=None,
            progbar=False,
            starting_epoch=1,
            time=False,
            callbacks=None,
            optimizer=torch.optim.SGD,
            **kwargs):
        """Train the WaveFunction.

        :param data: The training samples
        :type data: np.array
        :param epochs: The number of full training passes through the dataset.
                       Technically, this specifies the index of the *last* training
                       epoch, which is relevant if `starting_epoch` is being set.
        :type epochs: int
        :param pos_batch_size: The size of batches for the positive phase
                               taken from the data.
        :type pos_batch_size: int
        :param neg_batch_size: The size of batches for the negative phase
                               taken from the data. Defaults to `pos_batch_size`.
        :type neg_batch_size: int
        :param k: The number of contrastive divergence steps.
        :type k: int
        :param lr: Learning rate
        :type lr: float
        :param input_bases: The measurement bases for each sample. Must be provided
                            if training a ComplexWaveFunction.
        :type input_bases: np.array
        :param progbar: Whether or not to display a progress bar. If "notebook"
                        is passed, will use a Jupyter notebook compatible
                        progress bar.
        :type progbar: bool or str
        :param starting_epoch: The epoch to start from. Useful if continuing training
                               from a previous state.
        :type starting_epoch: int
        :param time: Whether to time the training procedure; if True, a Timer
                     callback is appended to `callbacks`.
        :type time: bool
        :param callbacks: Callbacks to run while training.
        :type callbacks: list[qucumber.callbacks.CallbackBase]
        :param optimizer: The constructor of a torch optimizer.
        :type optimizer: torch.optim.Optimizer
        :param kwargs: Keyword arguments to pass to the optimizer
        """
        if self.stop_training:  # terminate immediately if stop_training is true
            return

        disable_progbar = progbar is False
        progress_bar = tqdm_notebook if progbar == "notebook" else tqdm

        callbacks = CallbackList(callbacks if callbacks else [])
        if time:
            callbacks.append(Timer())

        neg_batch_size = neg_batch_size if neg_batch_size else pos_batch_size

        if isinstance(data, torch.Tensor):
            train_samples = (data.clone().detach().to(device=self.device,
                                                      dtype=torch.double))
        else:
            train_samples = torch.tensor(data,
                                         device=self.device,
                                         dtype=torch.double)

        if len(self.networks) > 1:
            all_params = [
                getattr(self, net).parameters() for net in self.networks
            ]
            all_params = list(chain(*all_params))
            optimizer = optimizer(all_params, lr=lr, **kwargs)
        else:
            optimizer = optimizer(self.rbm_am.parameters(), lr=lr, **kwargs)

        if input_bases is not None:
            z_samples = extract_refbasis_samples(
                train_samples, input_bases).to(device=self.device)
        else:
            z_samples = None

        callbacks.on_train_start(self)

        num_batches = ceil(train_samples.shape[0] / pos_batch_size)
        for ep in progress_bar(range(starting_epoch, epochs + 1),
                               desc="Epochs ",
                               disable=disable_progbar):
            data_iterator = self._shuffle_data(
                pos_batch_size,
                neg_batch_size,
                num_batches,
                train_samples,
                input_bases,
                z_samples,
            )
            callbacks.on_epoch_start(self, ep)

            for b, batch in enumerate(data_iterator):
                callbacks.on_batch_start(self, ep, b)

                all_grads = self.compute_batch_gradients(k, *batch)

                optimizer.zero_grad()  # clear any cached gradients

                # assign gradients to corresponding parameters
                for i, net in enumerate(self.networks):
                    rbm = getattr(self, net)
                    vector_to_grads(all_grads[i], rbm.parameters())

                optimizer.step()  # tell the optimizer to apply the gradients

                callbacks.on_batch_end(self, ep, b)
                if self.stop_training:  # check for stop_training signal
                    break

            callbacks.on_epoch_end(self, ep)
            if self.stop_training:  # check for stop_training signal
                break

        callbacks.on_train_end(self)
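
For context, below is a minimal usage sketch for this `fit` method. The `PositiveWaveFunction` constructor and the training-data file name are assumptions for illustration; only the `fit` keyword arguments come from the signature above.

import numpy as np
from qucumber.nn_states import PositiveWaveFunction

# Hypothetical data file of measurement samples (one spin configuration per row).
train_data = np.loadtxt("training_samples.txt")

# Assumed constructor arguments; num_hidden is a free modelling choice.
nn_state = PositiveWaveFunction(num_visible=train_data.shape[-1], num_hidden=10)

# Train using the keyword arguments documented above.
nn_state.fit(
    train_data,
    epochs=500,
    pos_batch_size=100,
    neg_batch_size=200,
    k=10,
    lr=0.01,
    progbar=True,
)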
Example #2
    def fit(
        self,
        data,
        input_bases,
        target=None,
        epochs=100,
        pos_batch_size=100,
        neg_batch_size=None,
        k=1,
        lr=1,
        progbar=False,
        starting_epoch=1,
        callbacks=None,
        time=False,
        optimizer=torch.optim.Adadelta,
        scheduler=torch.optim.lr_scheduler.MultiStepLR,
        lr_drop_epoch=50,
        lr_drop_factor=1.0,
        bases=None,
        train_to_fid=False,
        track_fid=False,
        **kwargs,
    ):
        r"""Trains the density matrix

        :param data: The training samples
        :type data: numpy.ndarray
        :param input_bases: The measurement bases for each sample
        :type input_bases: numpy.ndarray
        :param target: The target density matrix; used only to evaluate the
                       fidelity when `train_to_fid` or `track_fid` is set
        :type target: torch.Tensor
        :param epochs: The number of epochs to train for
        :type epochs: int
        :param pos_batch_size: The size of batches for the positive phase
        :type pos_batch_size: int
        :param neg_batch_size: The size of batches for the negative phase
        :type neg_batch_size: int
        :param k: Number of contrastive divergence steps
        :type k: int
        :param lr: Learning rate; its exact meaning depends on the chosen optimizer
        :type lr: float
        :param progbar: Whether or not to use a progress bar. Pass "notebook"
                        for a Jupyter notebook-friendly version
        :type progbar: bool or str
        :param starting_epoch: The epoch to start from
        :type starting_epoch: int
        :param callbacks: Callbacks to run while training
        :type callbacks: list[qucumber.callbacks.CallbackBase]
        :param time: Whether to time the training procedure; if True, a Timer
                     callback is appended to `callbacks`.
        :type time: bool
        :param optimizer: The constructor of a torch optimizer
        :type optimizer: torch.optim.Optimizer
        :param scheduler: The constructor of a torch scheduler
        :param lr_drop_epoch: The epoch, or list of epochs, at which the
                              base learning rate is dropped
        :type lr_drop_epoch: int or list[int]
        :param lr_drop_factor: The factor by which the scheduler will decrease the
                               learning rate after the prescribed number of epochs
        :type lr_drop_factor: float
        :param bases: All bases in which a measurement is made. Used to check gradients
        :type bases: numpy.ndarray
        :param train_to_fid: Instructs the RBM to end training prematurely if the
                             specified fidelity is reached. If it is never reached,
                             training will continue until the specified epoch
        :type train_to_fid: float or bool
        :param track_fid: A file to which the fidelity is written at every epoch.
                          Useful for monitoring a training run that is running
                          in the background
        :type track_fid: str or bool
        :param kwargs: Keyword arguments to pass to the optimizer
        """
        disable_progbar = progbar is False
        progress_bar = tqdm_notebook if progbar == "notebook" else tqdm
        lr_drop_epoch = ([lr_drop_epoch]
                         if isinstance(lr_drop_epoch, int) else lr_drop_epoch)

        callbacks = CallbackList(callbacks if callbacks else [])
        if time:
            callbacks.append(Timer())

        train_samples = data.clone().detach().double().to(device=self.device)

        neg_batch_size = neg_batch_size if neg_batch_size else pos_batch_size

        all_params = [getattr(self, net).parameters() for net in self.networks]
        all_params = list(chain(*all_params))

        optimizer = optimizer(all_params, lr=lr, **kwargs)
        scheduler = scheduler(optimizer, lr_drop_epoch, gamma=lr_drop_factor)

        z_samples = extract_refbasis_samples(train_samples, input_bases)

        num_batches = ceil(train_samples.shape[0] / pos_batch_size)

        callbacks.on_train_start(self)

        for ep in progress_bar(range(starting_epoch, epochs + 1),
                               desc="Epochs ",
                               disable=disable_progbar):

            data_iterator = self._shuffle_data(
                pos_batch_size,
                neg_batch_size,
                num_batches,
                train_samples,
                input_bases,
                z_samples,
            )
            callbacks.on_epoch_start(self, ep)

            for b, batch in enumerate(data_iterator):
                callbacks.on_batch_start(self, ep, b)
                all_grads = self.compute_batch_gradients(k, *batch)
                optimizer.zero_grad()

                for i, net in enumerate(self.networks):

                    rbm = getattr(self, net)
                    vector_to_grads(all_grads[i], rbm.parameters())

                optimizer.step()

                callbacks.on_batch_end(self, ep, b)

            callbacks.on_epoch_end(self, ep)

            scheduler.step()

            # Evaluate the fidelity against the target density matrix if needed
            if train_to_fid or track_fid:
                v_space = self.generate_hilbert_space(self.num_visible)
                fidel = ts.density_matrix_fidelity(self, target, v_space)

            if track_fid:
                with open(track_fid, "a") as f:
                    f.write(f"Epoch: {ep}\tFidelity: {fidel}\n")

            if train_to_fid and fidel >= train_to_fid:
                print(f"\n\nTarget fidelity of {train_to_fid} reached or exceeded!")
                break

        callbacks.on_train_end(self)
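
Similarly, here is a hedged usage sketch for the density-matrix trainer above. The `DensityMatrix` class, its constructor arguments, and the data file names are assumptions for illustration; the `fit` keyword arguments mirror the signature in this example.

import numpy as np
import torch
from qucumber.nn_states import DensityMatrix

# Hypothetical data files: measurement outcomes plus the basis label of each sample.
train_samples = torch.tensor(np.loadtxt("train_samples.txt"), dtype=torch.double)
train_bases = np.loadtxt("train_bases.txt", dtype=str)

# Assumed constructor; num_hidden and num_aux are free modelling choices.
nn_state = DensityMatrix(num_visible=train_samples.shape[-1], num_hidden=10, num_aux=10)

# Adadelta plus a MultiStepLR schedule, matching the defaults documented above.
nn_state.fit(
    train_samples,
    input_bases=train_bases,
    epochs=200,
    pos_batch_size=100,
    k=10,
    lr=1,
    lr_drop_epoch=[100, 150],
    lr_drop_factor=0.5,
)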