Code Example #1
def test_stimulate():
    shape = (800, 800)
    
    A = fk.stimulus.protocol(start=0, duration=2, period=50)
    A = fk.stimulus.linear(shape, direction="up", coverage=.05, modulus=1., protocol=A)
    B = fk.stimulus.protocol(start=10, duration=2, period=50)
    B = fk.stimulus.triangular(shape, direction="left", angle=30, coverage=0.5, modulus=1., protocol=B)
    C = fk.stimulus.protocol(start=30, duration=2, period=1000000)
    C = fk.stimulus.rectangular(shape, (50, 50), (1, 1), modulus=1., protocol=C)    
    stimuli = [A, B, C]
    
    A_is_active_at = set([0, 1, 50, 51, 100, 101, 150, 151, 200, 201, 250, 251])
    B_is_active_at = set([10, 11, 60, 61, 110, 111, 160, 161, 210, 211, 260, 261])
    C_is_active_at = set([30, 31])
    times = set(range(0, 300))
    
    X = np.zeros(shape, dtype="float32")
    for t in times:
        X = fk.model.stimulate(t, X, stimuli)
        if t in A_is_active_at or t in B_is_active_at or t in C_is_active_at:
            # stimulus active
            assert np.sum(np.nonzero(X)) != 0, "Failed stimulus test at time {}, nonzero is {}".format(t, np.nonzero(X))
        else:
            # stimulus non active
            assert np.sum(np.nonzero(X)) == 0, "Failed stimulus test at time {}, nonzero is {}".format(t, np.nonzero(X))
        X = np.zeros(shape, dtype="float32")
    return
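Note that the activity check `np.sum(np.nonzero(X)) != 0` sums the indices of the non-zero entries, so it reads 0 both when X is all zeros and when the only non-zero entry sits at index (0, 0). A short sketch of an unambiguous alternative (np.count_nonzero, not used in the original test):

import numpy as np

X = np.zeros((4, 4), dtype="float32")
X[0, 0] = 1.0  # the only non-zero entry sits at index (0, 0)

# summing the indices misreports activity here:
assert np.sum(np.nonzero(X)) == 0  # passes even though X is stimulated

# counting the non-zero entries is unambiguous:
assert np.count_nonzero(X) == 1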
Code Example #2
File: sigma.py Project: IPL-UV/jaxkern
def estimate_sigma_median_kth(X: np.ndarray,
                              Y: np.ndarray,
                              percent: float = 0.3) -> float:
    """Estimates the sigma using the median kth distance

    This calculates the sigma value using the kth percent
    of the distances. THe median value of that is the
    new sigma value.

    Parameters
    ----------
    dists : jax.numpy.ndarray
        the distance matrix already calculate (n_samples, n_samples)

    k : int
        the kth value from the (default=0.15)

    Returns
    -------
    kth_dist : jax.numpy.ndarray
        the neighbours up to the kth distance
    """

    # find the kth distance
    dists = _estimate_sigma_kth(X=X, Y=Y, percent=percent)

    # median distances
    sigma = np.median(dists[np.nonzero(dists)])
    return sigma
Code Example #3
    def _cluster(self, observations, targets):
        '''
        Arranges the observations so that each class contributes the same
        number of observations and the classes are separated from each other.

        Parameters
        ----------
        observations : array
            Dataset

        targets : array
            The class labels of the given dataset

        Returns
        -------
        * array
            The new dataset that has one additional axis to separate the observations
            of different classes
        '''
        clusters = []
        min_n_sample = float('inf')

        for c in range(self.num_of_classes):
            obs_of_same_class = observations[jnp.nonzero(targets == c)]
            n_obs = obs_of_same_class.shape[0]
            min_n_sample = min(min_n_sample, n_obs)
            clusters.append(obs_of_same_class.reshape((n_obs, -1)))

        return jnp.vstack([
            obs_of_same_class[jnp.newaxis, 0:min_n_sample, :]
            for obs_of_same_class in clusters
        ])
Code Example #4
File: sparse.py Project: yuanqing-wang/dgl
def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension
    If there was broadcasting in the forward pass, the gradient needs to be
    reduced on the broadcast dimension. This function checks the input tensor
    shape and the gradient shape and performs the reduction.

    Parameters
    ----------
    grad: Tensor
        Gradient tensor
    shape: tuple
        Shape of input tensor

    Returns
    -------
    Tensor
    """
    grad_shape = grad.shape[1:]
    in_shape = shape[1:]
    if in_shape == grad_shape:
        # no need to reduce
        return grad
    num_to_squeeze = len(grad_shape) - len(in_shape)
    # pad in_shape with leading 1s so both shapes have the same rank
    in_shape = (1,) * num_to_squeeze + in_shape
    # axes where the gradient shape differs from the (padded) input shape
    reduce_idx = jnp.nonzero(jnp.array(grad_shape) != jnp.array(in_shape))[0]
    reduce_idx = reduce_idx + 1  # skip batch dim
    if reduce_idx.size > 0:
        grad = grad.sum(axis=tuple(int(i) for i in reduce_idx), keepdims=True)
    return grad.reshape(-1, *shape[1:])
Code Example #5
File: sigma.py Project: IPL-UV/jaxkern
def estimate_sigma_median(X: np.ndarray, Y: np.ndarray) -> float:
    """Estimate sigma using the median distance

    Parameters
    ----------
    X : jax.numpy.ndarray
        input data (n_samples, n_features)
    Y : jax.numpy.ndarray
        input data (n_samples, n_features)

    Returns
    -------
    sigma : float
        the estimated sigma
    """
    # compute distance matrix
    dists = pdist_squareform(X, Y)

    # keep only the non-zero distances
    # dists = dists[np.nonzero(dists)]

    # get the median value
    sigma = np.median(dists[np.nonzero(dists)])

    return sigma
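The function depends on pdist_squareform, which is not shown here; a self-contained sketch of the same median heuristic, assuming Euclidean distances between the rows of X:

import jax.numpy as jnp

def median_heuristic(X):
    # pairwise Euclidean distances via ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
    sq_norms = jnp.sum(X ** 2, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T
    d = jnp.sqrt(jnp.maximum(d2, 0.0))
    # drop the zero diagonal (and any coincident points) before taking the median
    return jnp.median(d[jnp.nonzero(d)])

print(median_heuristic(jnp.arange(12.0).reshape(4, 3)))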
Code Example #6
def twiddle_factor_perm(n, stride):
    """The indices in a n x n matrix that marks where the entries of a butterfly factors are.
    """
    # TODO: The logic here is more complicated than necessary
    # I don't have time rn to find a simpler way
    factor = jnp.arange(1, 1 + 2 * n).reshape(n // 2, 2, 2)
    matrix_flat = twiddle_factor_to_matrix(factor, stride).flatten()
    nonzero_locs, = jnp.nonzero(matrix_flat)
    perm = nonzero_locs[jnp.argsort(matrix_flat[nonzero_locs])]
    return perm
Code Example #7
File: sparse_ops.py Project: paul-tqh-nguyen/jax
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    assert mat.ndim == 2

    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]

    true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
    data = jnp.where(true_nonzeros, data, 0)

    return data, row.astype(index_dtype), col.astype(index_dtype)
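The size= argument is what makes jnp.nonzero usable under jit: the output shape becomes static, and when the matrix has fewer than size non-zeros the index arrays are padded (with zeros by default), which is why the true_nonzeros mask above resets the data picked up by the padded slots. A minimal sketch of the same pattern:

import jax
import jax.numpy as jnp

@jax.jit
def to_coo(mat):
    nnz = 4  # static upper bound on the number of non-zeros
    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]
    # padded slots point at (0, 0); zero out the data they picked up
    data = jnp.where(jnp.arange(nnz) < (mat != 0).sum(), data, 0)
    return data, row, col

mat = jnp.array([[0.0, 1.0], [2.0, 0.0]])
print(to_coo(mat))  # data [1. 2. 0. 0.], row [0 1 0 0], col [1 0 0 0]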
Code Example #8
    def add_memories(self, batch: Dict[str, Array], predictions: Dict[str,
                                                                      Array]):
        """Save generated memories in-memory storage."""
        mention_mask = batch['mention_target_weights'] > 0
        memory_index_end = min(self.num_total_memories,
                               self.memory_index + mention_mask.sum())
        memory_index_len = memory_index_end - self.memory_index
        indices = self.memory_permutation[self.memory_index:memory_index_end]

        # We might not save mention encodings for all target mentions.
        # First, some of them might be padding or too close to a passage
        # boundary (in these cases we assume that `mention_target_weights` = 0).
        # Second, there might be more mentions than we actually need,
        # since we limit the number of total memories by `num_total_memories`.
        # Therefore, we create `mention_index` to select the subset of mentions
        # whose encodings we are planning to save.
        mention_index_0, mention_index_1 = jnp.nonzero(mention_mask)
        mention_index_0 = mention_index_0[:memory_index_len]
        mention_index_1 = mention_index_1[:memory_index_len]
        mention_index = (mention_index_0, mention_index_1)

        self.memory_embeddings[indices] = predictions['values'][mention_index]
        if self.memory_key_embeddings is not None:
            self.memory_key_embeddings[indices] = predictions['keys'][
                mention_index]
        self.memory_labels[indices] = batch['mention_target_ids'][
            mention_index]
        self.memory_text_hashes[indices] = batch['target_text_identifiers'][
            mention_index]
        self.memory_mention_hashes[indices] = batch['target_mention_hashes'][
            mention_index]

        # Convert to global batch positions
        n_devices, batch_size, _ = batch['text_ids'].shape
        mention_target_batch_positions = batch[
            'mention_target_batch_positions']
        mention_target_batch_positions = (
            mention_target_batch_positions +
            np.expand_dims(np.arange(n_devices), 1) * batch_size)
        mention_target_batch_positions = mention_target_batch_positions[
            mention_index]

        self.text_entities[indices] = batch['unique_mention_ids'].reshape(
            n_devices * batch_size, -1)[mention_target_batch_positions]
        self.text_ids[indices] = batch['text_ids'].reshape(
            n_devices * batch_size, -1)[mention_target_batch_positions]
        self.start_end_positions[
            indices,
            0] = batch['mention_target_start_positions'][mention_index]
        self.start_end_positions[
            indices, 1] = batch['mention_target_end_positions'][mention_index]
        self.memory_index = memory_index_end
Code Example #9
File: sparse_ops.py Project: paul-tqh-nguyen/jax
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    assert mat.ndim == 2
    m = mat.shape[0]

    row, col = jnp.nonzero(mat, size=nnz)
    data = mat[row, col]

    true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
    data = jnp.where(true_nonzeros, data, 0)
    row = jnp.where(true_nonzeros, row, m)
    indices = col.astype(index_dtype)
    indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
        jnp.cumsum(jnp.bincount(row, length=m)))
    return data, indices, indptr
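Note how padded entries are assigned row index m, one past the last row: jnp.bincount(row, length=m) then ignores them, relying on JAX's default behavior of dropping out-of-bounds scatter updates. A toy check of that behavior:

import jax.numpy as jnp

m = 3
row = jnp.array([0, 2, 3, 3])  # 3 == m marks the padded slots
print(jnp.bincount(row, length=m))  # [1 0 1]; the padded slots are dropped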
Code Example #10
File: maf.py Project: jxzhangjhu/NuX
            def inverse(z):
                x = jnp.zeros_like(z)

                # We need to build the output one dimension at a time
                def carry_body(carry, inputs):
                    x, idx = carry, inputs
                    mu, alpha = made(x, rng)
                    w = mu + z * jnp.exp(alpha)
                    # jax.ops.index_update was removed from JAX; .at[].set is
                    # the equivalent functional update
                    x = x.at[idx].set(w[idx])
                    return x, alpha[idx]

                indices = jnp.nonzero(
                    input_sel == (1 + jnp.arange(x.shape[0])[:, None]))[1]
                x, alpha_diag = jax.lax.scan(carry_body, x, indices)
                log_det = -alpha_diag.sum(axis=-1)
                return x, log_det
Code Example #11
File: _adaptive.py Project: google/deluca
    def __call__(self, x, A, B):

        play_i = np.argmax(self.weights)
        self.u = self.learners[play_i].get_action(x)

        # Update alive models
        for i in jnp.nonzero(self.alive)[0]:
            loss_i = self.policy_loss(self.learners[i], A, B, x, self.w)
            self.weights[i] *= np.exp(-self.eta * loss_i)
            self.weights[i] = min(max(self.weights[i], self.eps), self.inf)
            self.learners[i].update(x, u=self.u)

        self.t += 1

        # One is born every expert_density steps
        if self.t % self.expert_density == 0:
            self.alive = self.alive.at[self.t].set(1)
            self.weights[self.t] = self.eps
            self.learners[self.t] = self.base_controller(A,
                                                         B,
                                                         cost_fn=self.cost_fn)
            self.learners[self.t].x = x

        # At most one dies
        kill_list = jnp.where(self.tod == self.t)
        if len(kill_list[0]):
            kill = int(kill_list[0][0])
            if self.alive[kill]:
                self.alive = self.alive.at[kill].set(0)
                del self.learners[kill]
                self.weights[kill] = 0

        # Rescale
        max_w = np.max(self.weights)
        if max_w < 1:
            self.weights /= max_w

        # Get new noise (will be located at w[-1])
        self.w = self.w.at[0].set(x - self.A @ self.x + self.B @ self.u)
        self.w = jnp.roll(self.w, -1, axis=0)

        # Update System
        self.x, self.A, self.B = x, A, B

        return self.u
Code Example #12
def indexer_and_shape_from_mask(mask):
    # language=rst
    """
    Given a 2d mask array, create an index array that maps a vector with as many elements
    as there are nonzero elements in the mask onto an array of the same size as the mask,
    with the vector's elements placed at the nonzero positions. Also return the shape of
    the resulting array when the mask is applied.

    :param mask: 2d boolean mask array
    """
    index = np.zeros_like(mask, dtype=int)
    non_zero_indices = jnp.nonzero(mask)
    index[non_zero_indices] = jnp.arange(len(non_zero_indices[0])) + 1

    nonzero_x, nonzero_y = non_zero_indices
    n_rows = np.unique(nonzero_x).size
    assert nonzero_x.size % n_rows == 0
    n_cols = nonzero_x.size // n_rows
    shape = (n_rows, n_cols)
    return index, shape
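With the helper above in scope, a quick check of what it produces for a rectangular mask:

import numpy as np

mask = np.array([[True, True, False],
                 [True, True, False]])
index, shape = indexer_and_shape_from_mask(mask)
print(index)  # [[1 2 0] [3 4 0]]: 1-based positions into the vector, 0 elsewhere
print(shape)  # (2, 2)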
Code Example #13
def _bcoo_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
  mat = jnp.asarray(mat)
  mask = (mat != 0)
  if n_dense > 0:
    mask = mask.any([-(i + 1) for i in range(n_dense)])
  nonzero = lambda a: jnp.nonzero(a, size=nse) if a.ndim else ()
  for _ in range(n_batch):
    nonzero = vmap(nonzero, 0)
  indices = nonzero(mask)
  if not indices:
    indices = jnp.zeros(mask.shape[:n_batch] + (0, nse), index_dtype)
  else:
    indices = jnp.moveaxis(jnp.array(indices, index_dtype), 0, n_batch)
  data = bcoo_extract(indices, mat)

  true_nonzeros = jnp.arange(nse) < mask.sum(list(range(n_batch, mask.ndim)))[..., None]
  true_nonzeros = true_nonzeros[(n_batch + 1) * (slice(None),) + n_dense * (None,)]
  data = jnp.where(true_nonzeros, data, 0)

  return data, indices
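The vmap over nonzero only works because size=nse fixes the per-batch output shape; a stripped-down check of the batched behavior:

import jax
import jax.numpy as jnp

mask = jnp.array([[True, False, True],
                  [False, False, True]])
batched_nonzero = jax.vmap(lambda a: jnp.nonzero(a, size=2)[0])
print(batched_nonzero(mask))  # [[0 2] [2 0]]; the second row is padded with 0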
Code Example #14
File: kdtree2.py Project: pacargile/MINESweeper
    def __build(self, idx, maxes, mins):
        if len(idx) <= self.leafsize:
            return KDTree.leafnode(idx)
        else:
            data = self.data[idx]
            # maxes = np.amax(data,axis=0)
            # mins = np.amin(data,axis=0)
            d = np.argmax(maxes-mins)
            maxval = maxes[d]
            minval = mins[d]
            if maxval == minval:
                # all points are identical; warn user?
                return KDTree.leafnode(idx)
            data = data[:,d]

            # sliding midpoint rule; see Maneewongvatana and Mount 1999
            # for arguments that this is a good idea.
            split = (maxval+minval)/2
            less_idx = np.nonzero(data <= split)[0]
            greater_idx = np.nonzero(data > split)[0]
            if len(less_idx) == 0:
                split = np.amin(data)
                less_idx = np.nonzero(data <= split)[0]
                greater_idx = np.nonzero(data > split)[0]
            if len(greater_idx) == 0:
                split = np.amax(data)
                less_idx = np.nonzero(data < split)[0]
                greater_idx = np.nonzero(data >= split)[0]
            if len(less_idx) == 0:
                # _still_ zero? all must have the same value
                if not np.all(data == data[0]):
                    raise ValueError("Troublesome data array: %s" % data)
                split = data[0]
                less_idx = np.arange(len(data)-1)
                greater_idx = np.array([len(data)-1])

            # lessmaxes = maxes.copy()
            # lessmaxes = index_update(lessmaxes, index[d], split)
            lessmaxes = np.asarray([x if (ii != d) else split for ii,x in enumerate(maxes)])
            # lessmaxes[d] = split
            # greatermins = mins.copy()
            # greatermins = index_update(greatermins, index[d], split)
            greatermins = np.asarray([x if (ii != d) else split for ii,x in enumerate(mins)])
            # greatermins[d] = split
            return KDTree.innernode(d, split,
                    self.__build(idx[less_idx],lessmaxes,mins),
                    self.__build(idx[greater_idx],maxes,greatermins))
Code Example #15
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
    mat = jnp.asarray(mat)
    m, n = mat.shape
    mat_flat = jnp.ravel(mat)
    ind = jnp.nonzero(mat_flat, size=nnz)[0].astype(index_dtype)
    return mat_flat[ind], ind // n, ind % n
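Flattening first needs only a single nonzero call, with the row and column recovered by integer division and modulo. A quick check with the function above in scope:

import jax.numpy as jnp

mat = jnp.array([[0.0, 7.0], [5.0, 0.0]])
data, row, col = _coo_fromdense_impl(mat, nnz=2, index_dtype=jnp.int32)
print(data, row, col)  # [7. 5.] [0 1] [1 0]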
Code Example #16
def flatnonzero(a):
    return jnp.nonzero(jnp.ravel(a))[0]
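This mirrors np.flatnonzero; recent JAX releases also ship jnp.flatnonzero directly, with the same optional size argument as jnp.nonzero. A quick check:

import jax.numpy as jnp

a = jnp.array([[0, 3], [0, 4]])
print(jnp.nonzero(jnp.ravel(a))[0])  # [1 3]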
Code Example #17
def rand_argmax(a):
    return np.random.choice(jnp.nonzero(jnp.ravel(a == jnp.max(a)))[0])
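This mixes NumPy's global RNG with JAX arrays; a pure-JAX sketch of the same tie-breaking argmax (a hypothetical variant, not from the original project):

import jax
import jax.numpy as jnp

def rand_argmax_jax(key, a):
    ties = jnp.nonzero(jnp.ravel(a) == jnp.max(a))[0]  # all indices tied for the max
    return jax.random.choice(key, ties)

key = jax.random.PRNGKey(0)
print(rand_argmax_jax(key, jnp.array([1, 3, 3, 2])))  # 1 or 2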
Code Example #18
File: nonzero.py Project: gglin001/onnx-jax
def _nonzero(x):
    return jnp.asarray(jnp.nonzero(x))
Code Example #19
    def __getitem__(self, index: int):
        subset_idx = np.nonzero(index >= self.start_indices)[0][-1]
        offset = index - self.start_indices[subset_idx]
        return self._subsets[subset_idx][offset]
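The nonzero call finds the last subset whose start index is at or below the query; for sorted start_indices this is equivalent to a binary search. A toy check (start_indices is hypothetical):

import numpy as np

start_indices = np.array([0, 5, 12])
index = 7
subset_idx = np.nonzero(index >= start_indices)[0][-1]
assert subset_idx == np.searchsorted(start_indices, index, side='right') - 1
print(subset_idx)  # 1: global index 7 falls in the subset starting at 5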
Code Example #20
def nonzero(x):
  if isinstance(x, JaxArray): x = x.value
  return jnp.nonzero(x)
Code Example #21
File: ctc_objectives.py Project: tensorflow/lingvo
def collapse_and_remove_blanks(labels: jnp.ndarray,
                               seq_length: jnp.ndarray,
                               blank_id: int = 0):
  """Merge repeated labels into single labels and remove the designated blank symbol.

  Args:
    labels: Array of shape (batch, seq_length)
    seq_length: Array of shape (batch,), the sequence length of each batch element.
    blank_id: Optional id of the blank symbol

  Returns:
    Tuple of a jnp.ndarray of shape (batch, seq_length) with repeated labels
    collapsed, eg: [[A, A, B, B, A],
                    [A, B, C, D, E]] => [[A, B, A],
                                         [A, B, C, D, E]]
    and an int array of shape [batch] with the new sequence lengths.
  """
  b, t = labels.shape
  # Zap out blank
  blank_mask = 1 - jnp.equal(labels, blank_id)
  labels = (labels * blank_mask).astype(labels.dtype)

  # Mask labels that don't equal previous label.
  label_mask = jnp.concatenate([
      jnp.ones_like(labels[:, :1], dtype=jnp.int32),
      jnp.not_equal(labels[:, 1:], labels[:, :-1])
  ],
                               axis=1)

  # Filter labels that aren't in the original sequence.
  maxlen = labels.shape[1]
  seq_mask = sequence_mask(seq_length, maxlen=maxlen)
  label_mask = label_mask * seq_mask

  # remove repetitions from the labels
  ulabels = label_mask * labels

  # Count masks for new sequence lengths.
  label_mask = jnp.not_equal(ulabels, 0).astype(labels.dtype)
  new_seq_len = jnp.sum(label_mask, axis=1)

  # Mask indexes based on sequence length mask.
  new_maxlen = maxlen
  idx_mask = sequence_mask(new_seq_len, maxlen=new_maxlen)

  # Flatten everything and mask out labels to keep and sparse indices.
  flat_labels = jnp.reshape(ulabels, [-1])
  flat_idx_mask = jnp.reshape(idx_mask, [-1])

  indices = jnp.nonzero(flat_idx_mask, size=b * t)[0]
  values = jnp.nonzero(flat_labels, size=b * t)[0]
  updates = jnp.take_along_axis(flat_labels, values, axis=-1)

  # Scatter to flat shape.
  flat = jnp.zeros(flat_idx_mask.shape).astype(labels.dtype)
  flat = flat.at[indices].set(updates)
  # 0'th position in the flat array gets clobbered by later padded updates,
  # so reset it here to its original value
  flat = flat.at[0].set(updates[0])

  # Reshape back to square batch.
  batch_size = labels.shape[0]
  new_shape = [batch_size, new_maxlen]
  return (jnp.reshape(flat, new_shape).astype(labels.dtype),
          new_seq_len.astype(seq_length.dtype))
Code Example #22
File: dynamics.py Project: Badi96/sam_common
    def solve_direct(self, states, controls, T, homotopy, boundaries):

        # sanity
        assert states.shape[0] == controls.shape[0]
        assert states.shape[1] == self.state_dim
        assert controls.shape[1] == self.control_dim

        # system parameters
        params = self.params.values()

        # number of collocation nodes
        n = states.shape[0]

        # decision vector bounds
        @jit
        def get_bounds():
            zl = np.hstack((self.state_lb, self.control_lb))
            zl = np.tile(zl, n)
            zl = np.hstack(([0.0], zl))
            zu = np.hstack((self.state_ub, self.control_ub))
            zu = np.tile(zu, n)
            zu = np.hstack(([np.inf], zu))
            return zl, zu

        # decision vector maker
        @jit
        def flatten(states, controls, T):
            z = np.hstack((states, controls)).flatten()
            z = np.hstack(([T], z))
            return z

        # decision vector translator
        @jit
        def unflatten(z):
            T = z[0]
            z = z[1:].reshape(n, self.state_dim + self.control_dim)
            states = z[:, :self.state_dim]
            controls = z[:, self.state_dim:]
            return states, controls, T

        # fitness vector
        print('Compiling fitness...')

        @jit
        def fitness(z):

            # translate decision vector
            states, controls, T = unflatten(z)

            # time grid
            n = states.shape[0]
            times = np.linspace(0, T, n)

            # objective
            L = vmap(lambda state, control: self.lagrangian(
                state, control, homotopy, *params))
            L = L(states, controls)
            J = np.trapz(L, dx=T / (n - 1))

            # Lagrangian state dynamics constraints, and boundary constraints
            # e0 = self.collocate_lagrangian(states, controls, times, costs, homotopy, *params)
            e1 = self.collocate_state(states, controls, times, *params)
            e2, e3 = boundaries(states[0, :], states[-1, :])
            e = np.hstack((e1.flatten(), e2, e3))**2

            # fitness vector
            return np.hstack((J, e))

        # z = flatten(states, controls, T)
        # fitness(z)

        # sparse Jacobian
        print('Compiling Jacobian and its sparsity...')
        gradient = jit(jacfwd(fitness))
        z = flatten(states, controls, T)
        sparse_id = np.vstack((np.nonzero(gradient(z)))).T
        sparse_gradient = jit(lambda z: gradient(z)[[*sparse_id.T]])
        gradient_sparsity = jit(lambda: sparse_id)
        print('Jacobian has {} elements.'.format(sparse_id.shape[0]))

        # assign PyGMO problem methods
        self.fitness = fitness
        self.gradient = sparse_gradient
        self.gradient_sparsity = gradient_sparsity
        self.get_bounds = get_bounds
        self.get_nobj = jit(lambda: 1)
        nec = fitness(z).shape[0] - 1
        self.get_nec = jit(lambda: nec)

        # plot before
        states, controls, T = unflatten(z)
        self.plot('../img/direct_before.png', states, dpi=1000)

        # solve NLP with IPOPT
        print('Solving...')
        prob = pg.problem(udp=self)
        algo = pg.ipopt()
        algo.set_integer_option('max_iter', 1000)
        algo = pg.algorithm(algo)
        algo.set_verbosity(1)
        pop = pg.population(prob=prob, size=0)
        pop.push_back(z)
        pop = algo.evolve(pop)

        # save and plot solution
        z = pop.champion_x
        np.save('decision.npy', z)
        states, controls, T = unflatten(z)
        self.plot('../img/direct_after.png', states, dpi=1000)
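The sparsity-detection idiom above (evaluate the dense Jacobian once at the initial guess, record its non-zero pattern with nonzero, then extract only those entries) can be shown in isolation. A minimal sketch with a toy fitness function; note that detecting the pattern at a single point can miss entries that are only coincidentally zero there:

import jax
import jax.numpy as jnp

f = lambda z: jnp.array([z[0] ** 2, z[1] * z[2], 0.0])
jac = jax.jit(jax.jacfwd(f))

z0 = jnp.array([1.0, 2.0, 3.0])
sparse_id = jnp.vstack(jnp.nonzero(jac(z0))).T  # (row, col) of the structural non-zeros
sparse_jac = jax.jit(lambda z: jac(z)[tuple(sparse_id.T)])

print(sparse_id)       # [[0 0] [1 1] [1 2]]
print(sparse_jac(z0))  # [2. 3. 2.]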
Code Example #23
def nonzero_1d(input):
    # jnp.nonzero(input)[0] is already 1-D; the squeeze/flatten round trip
    # leaves the result unchanged
    x = (jnp.nonzero(input)[0]).squeeze()
    return x.flatten()
Code Example #24
def eval_vasp_xml(file="vasprun.xml", recip=False, norm_fermi=True, print_out=False):
    dft = pymatgen.io.vasp.outputs.Vasprun(file, parse_projected_eigen=False)
    orbital_energy = pd.read_csv("element_orbital_energy.csv").set_index("element")

    lattice = jnp.asarray(dft.get_trajectory().as_dict()['lattice']).squeeze()
    lattice_normed = lattice / jnp.linalg.norm(lattice, axis=1, keepdims=True)
    lattice_recip = jnp.asarray(Lattice(lattice).reciprocal_lattice.matrix)  # wrong!

    positions_base = dft.get_trajectory().as_dict()['base_positions']
    positions = jnp.dot(positions_base, lattice)

    k_points = jnp.asarray(dft.actual_kpoints)

    weights = jnp.asarray(dft.actual_kpoints_weights)  # how to use ?

    species_dict = {}
    species_arr = np.asarray(dft.atomic_symbols)
    count = 0
    print(species_arr)
    for key in dict.fromkeys(set(dft.atomic_symbols), {}):
        species_dict["species_" + Element(key).long_name] = {"symbol": key,
                                                             "number": count,
                                                             "Es": orbital_energy.loc["C", "E_s"],
                                                             "Ep": orbital_energy.loc["C", "E_p"],
                                                             "Ed": orbital_energy.loc["C", "E_d"],
                                                             }
        species_arr[species_arr == key] = count  # cycles through elements but returns correct one anyway
        count += 1
    species_arr = jnp.asarray(species_arr.astype(int))

    for key in dft.eigenvalues.keys():
        key_last = key
    true_inp = np.zeros(
        (dft.eigenvalues[key_last][:, :, 0].shape[0], dft.eigenvalues[key_last][:, :, 0].shape[1], len(dft.eigenvalues.keys())))
    count = 0
    if len(dft.eigenvalues.keys()) != 1:
        print("only one spin direction supported but", len(dft.eigenvalues.keys()), "where given")
    for key in dft.eigenvalues.keys():  # OrderedDictionary might be nice
        true_inp[:, :, count] = dft.eigenvalues[key][:, :, 0]  # what is [:, :, 0]?
        occupied = np.max(jnp.nonzero(dft.eigenvalues[key][:, :, 1])[1]) + 1
        fermi = find_fermi(true_inp, occupied)
        count += 1
    if norm_fermi:
        true_inp -= fermi
        print("E fermi calculated normed", find_fermi(true_inp, occupied, plot=False))
    if print_out:
        print("Lattice", type(lattice), lattice.shape, "\n", lattice)
        print("Lattice Normed", type(lattice_normed), lattice_normed.shape, lattice_normed)
        print("Lattice recip", type(lattice_recip), lattice_recip.shape, "\n", lattice_recip)
        print("Positions", type(positions_base), positions_base.shape, positions_base)
        print("Positions dot", type(positions), positions.shape, "\n", positions)
        print("kpts", k_points.shape, k_points)
        print("weights", weights.shape, weights)
        print("True shape", true_inp.shape, true_inp)
        print("species", species_arr.shape, species_arr, "\n", species_dict)
        # print("true", dft.eigenvalues[:].shape, "\n", dft.eigenvalues[dft.eigenvalues.keys()[0]][0, :, 0], "\n",
        #       dft.eigenvalues[dft.eigenvalues.keys()[0]][0, :, 1])
        print("E fermi vasp", dft.efermi)
        print("Highest occupied", occupied)
        print("E fermi calculated", fermi)
    if recip:
        return k_points, weights, lattice_recip, positions, species_arr, species_dict, true_inp, occupied
    else:
        return k_points, weights, lattice, positions, species_arr, species_dict, true_inp, occupied