def assign_vector(self, full_state: tf.Tensor):
    """Scatter a full state vector into the per-device ``tf.Variable`` pieces.

    The vector is reshaped to ``self.shapes["device"]``, split into
    ``self.ndevices`` sub-tensors along the first axis, and rearranged with
    ``op.transpose_state`` using ``self.qubits.transpose_order``. Slice ``i``
    of the rearranged result is then assigned to ``self.pieces[i]``.

    Args:
        full_state (tf.Tensor): Full state vector as a tensor of shape
            ``(2 ** nqubits,)``.
    """
    device_shape = self.shapes["device"]
    with tf.device(self.device):
        by_device = tf.reshape(full_state, device_shape)
        chunks = [by_device[index] for index in range(self.ndevices)]
        # Output buffer handed to the custom op, which returns the result.
        buffer = tf.zeros(device_shape, dtype=self.dtype)
        transposed = op.transpose_state(chunks, buffer, self.nqubits,
                                        self.qubits.transpose_order)
        for index in range(self.ndevices):
            target_piece = self.pieces[index]
            target_piece.assign(transposed[index])
def vector(self) -> tf.Tensor:
    """Merge the per-device pieces into one state vector.

    Returns:
        tf.Tensor: the full state with shape ``self.shapes["full"]``
        (documented as ``(2 ** nqubits,)``).

    Building this tensor allocates a full-size copy in addition to the
    pieces, so calling it roughly doubles memory usage.
    """
    with tf.device(self.device):
        if self.qubits.list == list(range(self.nglobal)):
            # Leading qubits are global: the piece index is the most
            # significant part of the flat index, so stack along axis 0.
            merged = tf.concat([piece[tf.newaxis] for piece in self.pieces],
                               axis=0)
            return tf.reshape(merged, self.shapes["full"])
        if self.qubits.list == list(range(self.nlocal, self.nqubits)):
            # Trailing qubits are global: the piece index is the least
            # significant part of the flat index, so stack along axis 1.
            merged = tf.concat([piece[:, tf.newaxis] for piece in self.pieces],
                               axis=1)
            return tf.reshape(merged, self.shapes["full"])
        # Any other qubit layout: fall back to the transpose op.
        buffer = tf.zeros(self.shapes["full"], dtype=self.dtype)
        return op.transpose_state(self.pieces, buffer, self.nqubits,
                                  self.qubits.reverse_transpose_order)
def test_transpose_state(nqubits, ndevices):
    """Check ``op.transpose_state`` against a NumPy reference transpose.

    Each of ten trials does the following:

    - Draws a random permutation of all qubits.
    - Splits a random complex state of ``2 ** nqubits`` amplitudes into
      ``ndevices`` equal pieces.
    - Compares the op's output with ``np.transpose`` of the state viewed as
      an ``nqubits``-rank tensor of size-2 axes.

    When a GPU is visible, only ``check_unimplemented_error`` is run for the
    op. No numerical comparison is made, because ``new_state`` is never
    filled on that path.
    """
    # Device availability does not change between trials; query it once.
    has_gpu = bool(tf.config.list_physical_devices("GPU"))
    for _ in range(10):
        # Generate global qubits randomly
        all_qubits = np.arange(nqubits)
        np.random.shuffle(all_qubits)
        qubit_order = list(all_qubits)
        state = utils.random_tensorflow_complex((2 ** nqubits,),
                                                dtype=tf.float64)

        # NumPy reference: permute the qubit axes and flatten.
        state_tensor = state.numpy().reshape(nqubits * (2,))
        target_state = np.transpose(state_tensor, qubit_order).ravel()

        new_state = tf.zeros_like(state)
        shape = (ndevices, int(state.shape[0]) // ndevices)
        state = tf.reshape(state, shape)
        pieces = [state[i] for i in range(ndevices)]
        if has_gpu:  # pragma: no cover
            # case not tested by GitHub workflows because it requires GPU
            check_unimplemented_error(op.transpose_state, pieces, new_state,
                                      nqubits, qubit_order, get_threads())
        else:
            new_state = op.transpose_state(pieces, new_state, nqubits,
                                           qubit_order, get_threads())
            # Only meaningful when the op actually ran; on the GPU branch
            # new_state would still be the zero tensor.
            np.testing.assert_allclose(target_state, new_state.numpy())