def norm_project(self, y, c):
    """Rescale y so that its largest absolute entry equals c.

    y is returned unchanged when it contains a NaN or when every entry
    already satisfies |y_i| <= c.

    Args:
        y: array to constrain.
        c: bound on the absolute value of the entries.

    Returns:
        Either y itself or y divided by its maximum absolute entry and
        multiplied by c.
    """
    contains_nan = np.any(np.isnan(y))
    if contains_nan or np.all(np.absolute(y) <= c):
        return y
    largest_magnitude = np.max(np.absolute(y))
    return y / largest_magnitude * c
예제 #2
0
def cost(params, inputs, outputs):
    r"""Evaluate the cost over the whole training dataset.

    For every training pair the unitary built from ``params`` is applied to
    the input ket, and the absolute value of the real part of its overlap
    with the expected output ket is accumulated. The cost is one minus the
    mean of those overlaps.

    Args:
        params (obj:`jnp.ndarray`): parameter vectors
            :math:`\vec{\theta}, \vec{\phi}, \vec{\omega}`
        inputs (obj:`jnp.ndarray`): input kets
            :math:`|\psi_{i} \rangle` in the dataset
        outputs (obj:`jnp.ndarray`): output kets
            :math:`U(\vec{\theta}, \vec{\phi},
            \vec{\omega})|ket_{input} \rangle`
            in the dataset

    Returns:
        float: cost (evaluated on the entire dataset)
            of parametrizing :math:`U(\vec{\theta},
            \vec{\phi}, \vec{\omega})` with `params`
    """
    thetas, phis, omegas = params
    unitary = Unitary(N)(thetas, phis, omegas)

    overlap_total = 0.0
    for idx in range(train_len):
        evolved = jnp.dot(unitary, inputs[idx])
        overlap = jnp.dot(outputs[idx].conjugate().T, evolved)
        overlap_total += jnp.absolute(jnp.real(overlap))

    mean_cost = 1 - (1 / train_len) * overlap_total
    # The accumulated overlap is indexed as a 2-D array; unwrap its single entry.
    return mean_cost[0][0]
예제 #3
0
파일: exact.py 프로젝트: FermiQ/netket
    def _reset(sampler, machine, parameters, state):
        """Recompute the normalised probability table and store it on ``state``.

        The array returned by ``to_array`` is raised to ``sampler.machine_pow``,
        its absolute value is taken, and the result is divided by its sum.

        Returns:
            The value of ``state.replace(pdf=...)`` with the new table.
        """
        amplitudes = to_array(sampler.hilbert, machine, parameters)
        weights = jnp.absolute(amplitudes ** sampler.machine_pow)
        return state.replace(pdf=weights / weights.sum())
예제 #4
0
def sigmoid(x):
    """
        Apply a sigmoid activation.

        Computes ``1 / (1 + |exp(-x)|)``.

        Args:
            x: The input sum for the activation function.

        Returns:
            The activation value.
    """
    decay = jnp.exp(-x)
    denominator = 1 + jnp.absolute(decay)
    return 1 / denominator
예제 #5
0
def computeTrackLength(eta):
    """Return the ratio of the maximum track length to the track length at eta.

    For |eta| <= 1.4 the track length equals the maximum, so the ratio is 1.
    Outside that range the length is min(+/-r, 108) - 4.4, where
    r = 267 * 2 / (exp(eta) - exp(-eta)) and the sign of r follows the sign
    of eta.
    """
    # max track length in cm. tracker radius - first pixel layer
    max_length = 108. - 4.4

    tan_theta = 2 / (np.exp(eta) - np.exp(-eta))
    # 267 cm: z position of the outermost disk of the TEC
    radius = 267. * tan_theta

    positive_side = np.minimum(radius, 108.) - 4.4
    negative_side = np.minimum(-radius, 108.) - 4.4
    outer_length = np.where(eta > 1.4, positive_side, negative_side)
    length = np.where(np.absolute(eta) <= 1.4, max_length, outer_length)

    return max_length / length
예제 #6
0
def cond_fun(maxiter, bound, feasStop, state):
    """Decide whether the iteration described by ``state`` should continue.

    Reads the counter from ``state[1]`` and the dual objective, primal-dual
    gap and max feasibility value from ``state[7:10]``. Returns True only when
    all of the following hold:

    * the counter has not exceeded ``maxiter``;
    * the dual objective is below ``bound``;
    * at least one of: |gap| > 1e-6, maxfeasible > feasStop, or
      (counter < 200 and maxfeasible >= 0).
    """
    logging.info('compiling cond_fun')
    counter = state[1]
    dual_objective, primal_dual_gap, maxfeasible = state[7:10]

    within_budget = counter <= maxiter
    below_bound = dual_objective < bound

    gap_open = np.absolute(primal_dual_gap) > 1e-6
    early_and_nonnegative = np.logical_and(counter < 200, maxfeasible >= 0)
    feasibility_pending = np.logical_or(maxfeasible > feasStop,
                                        early_and_nonnegative)
    still_working = np.logical_or(gap_open, feasibility_pending)

    return np.logical_and(within_budget,
                          np.logical_and(below_bound, still_working))
def ae(y_pred, y_true):
    ''' Description: absolute-error loss
        Args:
            y_pred : value predicted by method
            y_true : ground truth value
        Returns:
            |y_pred - y_true|; a length-1 vector result is unwrapped to its
            single element.
    '''
    wrapped_diff = np.array([y_pred]) - np.array([y_true])
    abs_err = np.absolute(wrapped_diff)[0]
    # y_pred is sometimes a (1,) vector rather than a scalar; unwrap that case
    # so callers always get a scalar for single-value predictions.
    return abs_err[0] if abs_err.shape == (1, ) else abs_err
예제 #8
0
 def norm_project(self, y, A, c):
     """Project y, under the norm induced by A, onto the box |x_i| <= c.

     y is returned unchanged when it contains a NaN or already lies inside
     the box. Otherwise a quadratic program is handed to ``solvers.qp`` with
     quadratic term A, linear term -A.y and the stacked constraints
     [I; -I] x <= c, and the solution is reshaped like y.

     NOTE(review): ``matrix`` and ``solvers`` are presumably the cvxopt
     objects of the same name -- confirm against the file's imports.
     """
     if np.any(np.isnan(y)) or np.all(np.absolute(y) <= c):
         return y

     original_shape = y.shape
     flat_y = np.ravel(y)
     n = flat_y.shape[0]

     quad_term = matrix(self.numpyify(A))
     lin_term = matrix(self.numpyify(-np.dot(A, flat_y)))

     # Upper and lower bounds expressed as a single "G x <= h" system.
     eye = np.identity(n)
     box_rows = np.append(eye, -eye, axis=0)
     box_matrix = matrix(self.numpyify(box_rows), tc='d')
     box_bounds = matrix(self.numpyify(np.repeat(c, 2 * n)), tc='d')

     qp_result = solvers.qp(quad_term, lin_term, box_matrix, box_bounds)
     minimiser = np.array(onp.array(qp_result['x']))
     return minimiser.squeeze().reshape(original_shape)
예제 #9
0
    def _L1(self):
        r"""Compute L1 regularization.

        When ``self.L1_reg`` is non-zero, the sum of absolute values of every
        layer's ``params['w']`` is added onto ``self.L1``. The attribute is
        not reset here, so the addition stacks on whatever value it already
        holds.

        :formula: .. math:: L_{1} = \sum_{i=1}^{L} |\beta|

        :return: L1 penalty
        :rtype: :obj:`float`
        """
        if self.L1_reg == 0.:
            return self.L1

        # L1 norm ; one regularization option is to enforce L1 norm to
        # be small
        for layer in self._layers:
            self.L1 += jnp.sum(jnp.absolute(layer.params['w']))
        return self.L1
예제 #10
0
def cost(phi, theta, omega, ket):
    r"""Returns the fidelity between the evolved state and :math:`|0 \rangle` state

    The ket is evolved with ``rot(phi, theta, omega)`` and the absolute value
    of its ``vdot`` with ``basis(2, 0)`` is returned.

    Parameters:
    ----------
        phi: int/float
             rotation angle for the second rotation around z axis

        theta: int/float
             rotation angle around y axis

        omega: int/float
             rotation angle for the first rotation around z axis

        ket: array[complex]
             array representing the ket to be acted upon by the rotation matrix

    Returns:
    -------
        float:
        fidelity between the :math:`|0 \rangle` state and the evolved state under rotation
    """
    rotation = rot(phi, theta, omega)
    evolved = jnp.dot(rotation, ket)
    target = basis(2, 0).full()
    return jnp.absolute(jnp.vdot(evolved.T, target))
    def update(self, params, x, y, loss=None):
        """
        Description: Updates parameters based on the true value and the loss.

        Three per-parameter running statistics are kept on the instance and
        refreshed on every call:

        * ``self.theta``: negated running sum of the gradients;
        * ``self.eta``: running sum of the squared gradients;
        * ``self.theta_max``: running element-wise maximum of ``|theta|``.

        The candidate parameters are ``theta / sqrt(max(theta_max, eta))``
        (or ``theta`` where that maximum is zero), and the returned value is
        whatever ``self.norm_project`` yields for each of them.

        Args:
            params (dict): Parameters of method pred method
            x: input to method; it is rolled by one position and its first
                entry replaced by ``y`` to form the next input.
                NOTE(review): the vstack with a (1, 1) block suggests a
                column of shape (n, 1) -- confirm against callers.
            y (float): true label
            loss (function): loss function. defaults to input value.
        Returns:
            dict: updated parameters keyed like ``params``
        Raises:
            AssertionError: if the optimizer is not initialized or ``params``
                is not a plain ``dict``.
        """
        assert self.initialized
        assert type(
            params
        ) == dict, "optimizers can only take params in dictionary format"

        grad = self.gradient(params, x, y,
                             loss=loss)  # defined in optimizers core class

        # theta accumulates the negated gradients.
        if self.theta is None:
            self.theta = {
                k: -dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.theta = {
                k: v - dw
                for (k, v), dw in zip(self.theta.items(), grad.values())
            }

        # eta accumulates the squared gradients.
        if self.eta is None:
            self.eta = {
                k: dw * dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.eta = {
                k: v + dw * dw
                for (k, v), dw in zip(self.eta.items(), grad.values())
            }

        # theta_max tracks the largest |theta| seen so far, element-wise.
        if self.theta_max is None:
            self.theta_max = {
                k: np.absolute(v)
                for (k, v) in self.theta.items()
            }
        else:
            self.theta_max = {
                k: np.where(np.greater(np.absolute(v), v_max), np.absolute(v),
                            v_max)
                for (k, v), v_max in zip(self.theta.items(),
                                         self.theta_max.values())
            }

        # Scale theta by 1 / sqrt(max(theta_max, eta)); where that maximum is
        # zero the unscaled theta is used to avoid dividing by zero.
        new_params = {
            k: np.where(np.equal(0.0, np.maximum(theta_max, eta)), theta,
                        theta / np.sqrt(np.maximum(theta_max, eta)))
            for (k, w), theta, theta_max, eta in zip(params.items(
            ), self.theta.values(), self.theta_max.values(), self.eta.values())
        }

        # Build the next input: shift by one and put the true label in front.
        # jax.ops.index_update no longer exists in current JAX releases; the
        # functional ``.at[...].set(...)`` form is its equivalent.
        x_new = np.roll(x, 1)
        x_new = jax.numpy.asarray(x_new).at[0].set(y)
        y_t = self.pred(params=new_params, x=x_new)

        x_plus_bias_new = np.vstack((np.ones((1, 1)), x_new))
        # NOTE(review): norm_project receives four positional arguments here
        # (scale, input with bias row, prediction, parameter) -- confirm this
        # matches the signature defined on this class.
        new_mapped_params = {
            k: self.norm_project(
                np.where(np.equal(0.0, np.maximum(theta_max, eta)), 0.0,
                         1.0 / np.sqrt(np.maximum(theta_max, eta))),
                x_plus_bias_new, y_t, p)
            for (k, p), theta_max, eta in zip(
                new_params.items(), self.theta_max.values(), self.eta.values())
        }

        return new_mapped_params
예제 #12
0
def absolute(x):
  """Element-wise absolute value; unwraps a JaxArray input and wraps the result."""
  raw = x.value if isinstance(x, JaxArray) else x
  return JaxArray(jnp.absolute(raw))
예제 #13
0
 def soft_thresholding(z, threshold):
     """Shrink z towards zero by ``threshold``, zeroing entries within it."""
     shrunk_magnitude = jnp.maximum(jnp.absolute(z) - threshold, 0.0)
     return jnp.sign(z) * shrunk_magnitude
예제 #14
0
 def loss(W):
     """Half the sum of the squared magnitudes of ``C @ W``."""
     # Taking the absolute value first keeps complex inputs supported.
     magnitudes = jnp.absolute(C @ W)
     return (magnitudes**2).sum() / 2
예제 #15
0
def train(loss,
          params,
          lr,
          validation_loss=None,
          optimizer='adam',
          nsteps=1001,
          verbose=False,
          log=False,
          tol=1e-5):
    """Train a model using the provided loss and parameters.

  The parameter tree is unpacked into exactly two leaves, a scalar ``b`` and
  an iterable ``w`` of scalars, both for checkpointing and for the
  convergence test.

  Args:
    loss: Loss function.
    params: Parameter dictionary.
    lr: Learning Rate.
    validation_loss: Optional validation loss, for checkpoints.
    optimizer: One of 'sgd', 'adam'.
    nsteps: Number of iterations of SGD.
    verbose: If true, log the loss at every 100 iterations.
    log: If true, populate checkpoints every 100 iterations.
    tol: If the l_infinity norm of the gradient with respect to ``w`` is
      smaller than this value, then assume convergence.

  Returns:
    Learned params and checkpoints

  Raises:
    ValueError: If ``optimizer`` is neither 'adam' nor 'sgd'.

  """
    if optimizer == 'adam':
        tx = optax.adam(learning_rate=lr)
    elif optimizer == 'sgd':
        tx = optax.sgd(learning_rate=lr)
    else:
        raise ValueError(
            f'Expected "adam" or "sgd" for optimizer, got {optimizer}')

    def evaluate_validation(current_params):
        """Return the validation loss as a Python number, or None if unset."""
        if validation_loss is None:
            return None
        return validation_loss(current_params).item()

    def make_checkpoint(step, current_params, train_loss, val_loss):
        """Build one checkpoint record for the given step."""
        # jax.tree_util.tree_leaves replaces the removed top-level
        # jax.tree_leaves alias.
        b, w = jax.tree_util.tree_leaves(current_params)
        entry = {
            'Step': step,
            'Train Loss': train_loss.item(),
            'b': b.item(),
        }
        if val_loss is not None:
            entry['Validation Loss'] = val_loss
        for j, this_w in enumerate(w):
            entry[f'w{j+1}'] = this_w.item()
        return entry

    opt_state = tx.init(params)
    loss_grad_fn = jax.value_and_grad(loss)

    checkpoints = []
    i = 0  # keeps the final checkpoint's step defined when nsteps == 0
    for i in range(nsteps):
        loss_val, grads = loss_grad_fn(params)

        if i % 100 == 0 and (log or verbose):
            # Evaluate once and share between checkpointing and logging, so
            # verbose logging also works when log is False.
            iter_validation_loss = evaluate_validation(params)

            if log:
                checkpoints.append(
                    make_checkpoint(i, params, loss_val,
                                    iter_validation_loss))

            if verbose:
                logging.info('Step %d: Train Loss %f', i, loss_val)
                if iter_validation_loss is not None:
                    logging.info('Step %d: Validation Loss %f', i,
                                 iter_validation_loss)

        _, w_grad = jax.tree_util.tree_leaves(grads)
        if jnp.max(jnp.absolute(w_grad)) < tol:
            logging.info('Converged at Step %d', i)
            break

        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

    # Log the final parameters
    if log:
        loss_val, _ = loss_grad_fn(params)
        checkpoints.append(
            make_checkpoint(i, params, loss_val,
                            evaluate_validation(params)))

    return params, checkpoints
예제 #16
0
def weight_magnitude(weights):
    """Return the element-wise magnitude of ``weights`` as saliency scores."""
    saliencies = jnp.absolute(weights)
    return saliencies
예제 #17
0
# Optimise the three rotation angles with repeated gradient steps and record
# every intermediate state in state_hist.
# Relies on names defined outside this span: key, epochs, max_iters, alpha,
# normal, grad, cost, rot, basis, Qobj, onp.
weights = normal(key=key, shape=(3, ))
init_ket = basis(2, 1).full()
# Gradient of cost with respect to its first three positional arguments.
der_cost = grad(cost, argnums=[0, 1, 2])
state_hist = []

for epoch in range(epochs):
    iters = 0
    diff = 1
    tol = 1e-7
    # The inner loop keeps going only while EVERY component changed by more
    # than tol in the last step, and the iteration cap has not been reached.
    while jnp.all(diff > tol) and iters < max_iters:
        prev_weights = weights
        der = jnp.asarray(der_cost(*prev_weights.T, init_ket))
        # The step is added, so the weights move in the direction that
        # increases cost.
        weights = weights + alpha * der
        state_hist.append(Qobj(onp.dot(rot(*weights), init_ket)))
        iters += 1
        diff = jnp.absolute(weights - prev_weights)
    fidel = cost(*weights.T, init_ket)
    progress = [epoch + 1, fidel]
    # epoch % 1 is 0 for every integer epoch, so this prints each epoch.
    if (epoch) % 1 == 0:
        print("Epoch: {:2f} | Fidelity: {:3f}".format(*jnp.asarray(progress)))

# ## Bloch Sphere Visualization
#
# As we see above, we started off with a very low fidelity (~0.26). With gradient descent iterations, we progressively achieve better fidelities via better parameters, $\phi$, $\theta$, and $\omega$. To see it visually, we render our states onto a Bloch sphere.
#
# We see how our optimizer (Gradient Descent in this case) finds a (nearly) optimal path to walk from $|1 \rangle$ (green arrow pointing exactly south) to very close to the target state $|0 \rangle$ (brown arrow pointing exactly north), as desired.

b = Bloch()
b.add_states(Qobj(init_ket))
b.add_states(basis(2, 0))
for state in range(0, len(state_hist), 6):
예제 #18
0
 def log_det_fn(x):
   """Return log|jac_fn| evaluated on the flattened input, shaped like x."""
   arr = jnp.asarray(x)
   flat_jac = jac_fn(arr.reshape(-1))
   return jnp.log(jnp.absolute(flat_jac)).reshape(arr.shape)