Example no. 1
    def fid_err_grad_wrapper(self, *args):
        """
        Get the gradient of the fidelity error with respect to all of the
        variables, i.e. the ctrl amplitudes in each timeslot

        This is called by the generic optimisation algorithm as the gradient
        of the function to be minimised wrt the variables. The argument is
        the current variable values, i.e. control amplitudes, passed as
        a flat array. Hence these are reshaped as [nTimeslots, n_ctrls]
        and then used to update the stored ctrl values (if they have changed)

        Although the optimisation algorithms have a check within them for
        function convergence, i.e. local minima, the sum of the squares
        of the normalised gradient is checked explicitly, and the
        optimisation is terminated if this is below the min_gradient_norm
        condition
        """
        # *** update stats ***
        if self.stats is not None:
            self.stats.num_grad_func_calls += 1
            if self.log_level <= logging.DEBUG:
                logger.debug("gradient call {}".format(
                    self.stats.num_grad_func_calls))
        amps = args[0].copy().reshape(self.dynamics.ctrl_amps.shape)
        self.dynamics.update_ctrl_amps(amps)
        fid_comp = self.dynamics.fid_computer
        # get_fid_err_gradient returns the gradient of the fidelity error
        # wrt each ctrl amplitude in each timeslot; it also sets the
        # fidelity computer's grad_norm attribute, which is used below
        grad = fid_comp.get_fid_err_gradient()

        if self._grad_norm_fpath is not None:
            with open(self._grad_norm_fpath, 'a') as fh:
                fh.write("{:<10n}{:14.6g}\n".format(
                    self.stats.num_grad_func_calls, fid_comp.grad_norm))

        if self.config.test_out_grad:
            # save gradients to file
            dyn = self.dynamics
            fname = "grad_{}_{}_{}_{}_call{}{}".format(
                self.id_text,
                dyn.id_text,
                dyn.prop_computer.id_text,
                dyn.fid_computer.id_text,
                self.stats.num_grad_func_calls,
                self.config.test_out_f_ext)

            fpath = os.path.join(self.config.test_out_dir, fname)
            np.savetxt(fpath, grad, fmt='%11.4g')

        tc = self.termination_conditions
        if fid_comp.grad_norm < tc.min_gradient_norm:
            raise errors.GradMinReachedTerminate(fid_comp.grad_norm)
        return grad.flatten()
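
A minimal, self-contained sketch of the reshape-and-check pattern used by the
wrapper above, under made-up assumptions: n_timeslots, n_ctrls and
min_gradient_norm below are placeholders, not values taken from the class.

import numpy as np

# Placeholder dimensions and termination threshold (illustration only).
n_timeslots, n_ctrls = 10, 2
min_gradient_norm = 1e-5

# The optimiser supplies the variables as one flat array ...
flat_amps = np.linspace(-1.0, 1.0, n_timeslots * n_ctrls)

# ... which is reshaped to [n_timeslots, n_ctrls] before updating the
# stored control amplitudes.
amps = flat_amps.copy().reshape((n_timeslots, n_ctrls))

# Stand-in for fid_comp.get_fid_err_gradient(): same shape as amps.
grad = 2.0 * amps

# The norm compared against min_gradient_norm; here taken as the square
# root of the sum of squares of the gradient elements.
grad_norm = np.sqrt(np.sum(grad ** 2))
if grad_norm < min_gradient_norm:
    print("would raise errors.GradMinReachedTerminate({})".format(grad_norm))

# The optimiser receives the gradient flattened back to 1-D.
flat_grad = grad.flatten()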
Example no. 2
    def fid_err_grad_wrapper(self, *args):
        """
        Get the gradient of the fidelity error with respect to all of the
        variables, i.e. the ctrl amplitudes in each timeslot

        This is called by the generic optimisation algorithm as the gradient
        of the function to be minimised wrt the variables. The argument is
        the current variable values, i.e. control amplitudes, passed as
        a flat array. Hence these are reshaped as [nTimeslots, n_ctrls]
        and then used to update the stored ctrl values (if they have changed)

        Although the optimisation algorithms have a check within them for
        function convergence, i.e. local minima, the sum of the squares
        of the normalised gradient is checked explicitly, and the
        optimisation is terminated if this is below the min_gradient_norm
        condition
        """
        # *** update stats ***
        self.num_grad_func_calls += 1
        if self.stats is not None:
            self.stats.num_grad_func_calls = self.num_grad_func_calls
            if self.log_level <= logging.DEBUG:
                logger.debug("gradient call {}".format(
                    self.stats.num_grad_func_calls))
        amps = self._get_ctrl_amps(args[0].copy())
        self.dynamics.update_ctrl_amps(amps)
        fid_comp = self.dynamics.fid_computer
        # get_fid_err_gradient returns the gradient of the fidelity error
        # wrt each ctrl amplitude in each timeslot; it also sets the
        # fidelity computer's grad_norm attribute, which is used below
        grad = fid_comp.get_fid_err_gradient()

        if self.iter_summary:
            self.iter_summary.grad_func_call_num = self.num_grad_func_calls
            self.iter_summary.grad_norm = fid_comp.grad_norm

        if self.dump:
            if self.dump.dump_grad_norm:
                self.dump.update_grad_norm_log(fid_comp.grad_norm)

            if self.dump.dump_grad:
                self.dump.update_grad_log(grad)

        tc = self.termination_conditions
        if fid_comp.grad_norm < tc.min_gradient_norm:
            raise errors.GradMinReachedTerminate(fid_comp.grad_norm)
        return grad.flatten()
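
For context, a hedged sketch of how a driver loop could hand a wrapper like
this to a SciPy optimiser and catch the early-termination exception. The
stand-in error/gradient functions and the local GradMinReachedTerminate class
below are illustrative placeholders, not the real qutip.control code.

import numpy as np
import scipy.optimize as spopt

class GradMinReachedTerminate(Exception):
    """Stand-in for errors.GradMinReachedTerminate (illustration only)."""

N_TS, N_CTRLS = 5, 2   # placeholder timeslot / control counts

def fid_err_wrapper(flat_amps):
    # stand-in scalar fidelity error: a simple quadratic bowl
    return float(np.sum(flat_amps ** 2))

def fid_err_grad_wrapper(flat_amps):
    # reshape to [n_timeslots, n_ctrls], compute a stand-in gradient,
    # and terminate early if its norm falls below the threshold
    grad = 2.0 * flat_amps.reshape((N_TS, N_CTRLS))
    grad_norm = np.sqrt(np.sum(grad ** 2))
    if grad_norm < 1e-6:
        raise GradMinReachedTerminate(grad_norm)
    return grad.flatten()

x0 = np.ones(N_TS * N_CTRLS)
try:
    result = spopt.minimize(fid_err_wrapper, x0,
                            jac=fid_err_grad_wrapper,  # flat in, flat out
                            method="L-BFGS-B")
    print("converged, fid err:", result.fun)
except GradMinReachedTerminate as exc:
    # the exception raised inside the gradient callback ends the run early
    print("gradient norm below threshold:", exc)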