def test_abstract_backend_outer_product_not_implemented():
  """`AbstractBackend.outer_product` must raise NotImplementedError."""
  abstract_backend = AbstractBackend()
  lhs = np.ones((2, 2))
  rhs = np.ones((2, 2))
  with pytest.raises(NotImplementedError):
    abstract_backend.outer_product(lhs, rhs)
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
    """Jittable ncon function. Performs the contraction of `tensors`.

    The contraction proceeds in stages: partial traces inside each tensor,
    summation over positive labels that occur only once, binary (tensordot)
    and batch contractions following `con_order`, outer products / remaining
    batch contractions of whatever is left, and a final transposition so the
    result's axes follow `out_order`.

    Args:
      tensors: List of tensors. The list is copied, so the caller's list is
        left untouched.
      flat_labels: A tuple of integers: the labels of all tensors,
        concatenated in the order of `tensors`.
      sizes: Tuple of int used to reconstruct `network_structure` from
        `flat_labels`; `sizes[n]` is the number of labels of `tensors[n]`.
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If a full pass over the remaining `con_order` makes no
        progress because every remaining label is still a batch label.
    """
    # Work on a copy: the stages below replace, pop and append entries of
    # `tensors`, which would otherwise consume the caller's list.
    tensors = list(tensors)
    # some jax-juggling to avoid retracing ...
    slices = np.append(0, np.cumsum(sizes))
    network_structure = [
        np.array(flat_labels)[slices[n]:slices[n + 1]]
        for n in range(len(slices) - 1)
    ]
    con_order = np.array(con_order)
    out_order = np.array(out_order)
    # Snapshots of the inputs, used only in the error message below.
    # pylint: disable=unnecessary-comprehension
    init_con_order = [c for c in con_order]
    init_network_structure = [list(c) for c in network_structure]
    # partial trace: labels traced out inside a single tensor are dropped
    # from `con_order`.
    for n, tensor in enumerate(tensors):
        tensors[n], network_structure[n], contracted_labels = _partial_trace(
            tensor, network_structure[n], backend_obj)

        if len(contracted_labels) > 0:
            con_order = np.delete(
                con_order,
                np.intersect1d(con_order,
                               contracted_labels,
                               return_indices=True,
                               assume_unique=True)[1])

    # sum over all positive labels appearing only once in `network_structure`
    unique_labels, label_cnts = np.unique(np.concatenate(network_structure),
                                          return_counts=True)
    contractable_labels = unique_labels[np.logical_and(label_cnts == 1,
                                                       unique_labels > 0)]

    # update con_order
    if len(contractable_labels) > 0:
        con_order = np.delete(
            con_order,
            np.nonzero(np.isin(con_order, contractable_labels))[0])

    # collapse axes of single-labelled tensors
    locs = [
        n for n, labels in enumerate(network_structure)
        if np.any(np.isin(labels, contractable_labels))
    ]
    for loc in locs:
        labels = network_structure[loc]
        contractable_inds = np.nonzero(np.isin(labels, contractable_labels))[0]
        network_structure[loc] = np.delete(labels, contractable_inds)
        tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

    # perform binary and batch contractions
    # `skip_counter` counts consecutive skips of batch labels; it is reset
    # after every contraction so the error fires only when a whole pass over
    # `con_order` made no progress.
    skip_counter = 0
    # Batch labels: labels occurring more than twice, or negative labels
    # occurring exactly twice. `batch_cnts` tracks their remaining counts.
    unique_labels, label_cnts = np.unique(np.concatenate(network_structure),
                                          return_counts=True)
    mask = np.logical_or(label_cnts > 2,
                         np.logical_and(label_cnts == 2, unique_labels < 0))
    batch_labels = unique_labels[mask]
    batch_cnts = label_cnts[mask]

    while len(con_order) > 0:
        # the next index to be contracted
        cont_ind = con_order[0]
        if cont_ind in batch_labels:
            # if its still a batch index then do it later
            con_order = np.append(np.delete(con_order, 0), cont_ind)
            skip_counter += 1
            # avoid being stuck in an infinite loop
            if skip_counter > len(con_order):
                raise ValueError(
                    f"ncon seems stuck in an infinite loop. \n"
                    f"Please check if `con_order` = {init_con_order} is "
                    f"a valid contraction order for \n"
                    f"`network_structure` = {init_network_structure}")
            continue

        # A contraction is about to happen, i.e. progress is being made:
        # previous skips must not count towards the stuck detection.
        skip_counter = 0

        # find locations of `cont_ind` in `network_structure`
        locs = [
            n for n, labels in enumerate(network_structure)
            if sum(labels == cont_ind) > 0
        ]

        # pop the later position first so the earlier index stays valid
        t2 = tensors.pop(locs[1])
        t1 = tensors.pop(locs[0])
        labels_t2 = network_structure.pop(locs[1])
        labels_t1 = network_structure.pop(locs[0])

        common_labels, t1_cont, t2_cont = np.intersect1d(labels_t1,
                                                         labels_t2,
                                                         assume_unique=True,
                                                         return_indices=True)
        # check if there are batch labels (i.e. labels appearing more than twice
        # in `network_structure`).
        common_batch_labels = np.intersect1d(batch_labels,
                                             common_labels,
                                             assume_unique=True)
        if common_batch_labels.shape[0] > 0:
            # case1: both tensors have one or more common batch indices -> use matmul
            ix, _ = np.nonzero(
                batch_labels[:, None] == common_batch_labels[None, :])
            # reduce the counts of these labels in `batch_cnts` by 1
            batch_cnts[ix] -= 1
            # if the count of a positive label falls below 3
            # remove it from `batch_labels`
            mask = np.logical_or(
                batch_cnts > 2,
                np.logical_and(batch_cnts == 2, batch_labels < 0))
            batch_labels = batch_labels[mask]
            batch_cnts = batch_cnts[mask]

            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        # in all other cases do a regular tensordot
        else:
            ind_sort = np.argsort(t1_cont)
            tensors.append(
                backend_obj.tensordot(t1,
                                      t2,
                                      axes=(tuple(t1_cont[ind_sort]),
                                            tuple(t2_cont[ind_sort]))))
            network_structure.append(
                np.append(np.delete(labels_t1, t1_cont),
                          np.delete(labels_t2, t2_cont)))

            # remove contracted labels from con_order
            con_order = np.delete(
                con_order,
                np.intersect1d(con_order,
                               common_labels,
                               assume_unique=True,
                               return_indices=True)[1])

    # perform outer products and remaining batch contractions
    while len(tensors) > 1:
        unique_labels, label_cnts = np.unique(
            np.concatenate(network_structure), return_counts=True)
        batch_labels = unique_labels[np.logical_or(
            label_cnts > 2, np.logical_and(label_cnts == 2,
                                           unique_labels < 0))]

        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()
        # check if there are negative batch indices left
        # (have to be collapsed to a single one)
        common_labels, t1_cont, t2_cont = np.intersect1d(labels_t1,
                                                         labels_t2,
                                                         assume_unique=True,
                                                         return_indices=True)
        common_batch_labels = np.intersect1d(batch_labels,
                                             common_labels,
                                             assume_unique=True)
        if common_batch_labels.shape[0] > 0:
            # collapse all negative batch indices
            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        else:
            tensors.append(backend_obj.outer_product(t1, t2))
            network_structure.append(np.append(labels_t1, labels_t2))

    # if necessary do a final permutation: row i of the comparison matrix is
    # out_order[i], column j is the j-th remaining label; np.nonzero returns
    # the matches in row order.
    if len(network_structure[0]) > 0:
        i1, i2 = np.nonzero(out_order[:,
                                      None] == network_structure[0][None, :])
        return backend_obj.transpose(tensors[0], tuple(i1[i2]))
    return tensors[0]
# Example #3
0
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
    """Jittable ncon function. Performs the contraction of `tensors`.

    List-based variant: labels are kept as plain Python lists. The
    contraction proceeds in stages: partial traces inside each tensor,
    summation over positive labels that occur only once, binary (tensordot)
    and batch contractions following `con_order`, outer products / remaining
    batch contractions of whatever is left, and a final transposition so the
    result's axes follow `out_order`.

    Args:
      tensors: List of tensors. The list is copied, so the caller's list is
        left untouched.
      flat_labels: A tuple of integers: the labels of all tensors,
        concatenated in the order of `tensors`.
      sizes: Tuple of int used to reconstruct `network_structure` from
        `flat_labels`; `sizes[n]` is the number of labels of `tensors[n]`.
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If a full pass over the remaining `con_order` makes no
        progress because every remaining label is still a batch label.
    """
    # Work on a copy: the stages below replace, pop and append entries of
    # `tensors`, which would otherwise consume the caller's list.
    tensors = list(tensors)
    # some jax-juggling to avoid retracing ...
    flat_labels = list(flat_labels)
    slices = np.append(0, np.cumsum(sizes))
    network_structure = [
        flat_labels[slices[n]:slices[n + 1]] for n in range(len(slices) - 1)
    ]
    out_order = list(out_order)
    con_order = list(con_order)
    # Snapshots of the inputs, used only in the error message below.
    # pylint: disable=unnecessary-comprehension
    init_con_order = [c for c in con_order]
    init_network_structure = [c for c in network_structure]

    # partial trace: labels traced out inside a single tensor are dropped
    # from `con_order`.
    for n, tensor in enumerate(tensors):
        tensors[n], network_structure[n], contracted_labels = _partial_trace(
            tensor, network_structure[n], backend_obj)
        if len(contracted_labels) > 0:
            con_order = [c for c in con_order if c not in contracted_labels]

    flat_labels = [l for sublist in network_structure for l in sublist]
    # sum over all positive labels appearing only once in `network_structure`
    contractable_labels = [
        l for l in flat_labels if (flat_labels.count(l) == 1) and (l > 0)
    ]
    # update con_order
    if len(contractable_labels) > 0:
        con_order = [o for o in con_order if o not in contractable_labels]
    # collapse axes of single-labelled tensors
    locs = []
    for n, labels in enumerate(network_structure):
        if len(set(labels).intersection(contractable_labels)) > 0:
            locs.append(n)

    for loc in locs:
        labels = network_structure[loc]
        # Axis positions of only those contractable labels that belong to
        # this tensor. `contractable_labels` also holds labels of other
        # tensors, for which `labels.index(l)` would raise ValueError.
        contractable_inds = [
            i for i, l in enumerate(labels) if l in contractable_labels
        ]
        network_structure[loc] = [
            l for l in labels if l not in contractable_labels
        ]
        tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

    # perform binary and batch contractions
    # `skip_counter` counts consecutive skips of batch labels; it is reset
    # after every contraction so the error fires only when a whole pass over
    # `con_order` made no progress.
    skip_counter = 0
    # Batch labels: labels occurring more than twice, or negative labels
    # occurring exactly twice. `batch_cnts` tracks their remaining counts.
    # (Labels summed out above occur once, so they never qualify.)
    batch_labels = []
    batch_cnts = []
    for l in set(flat_labels):
        cnt = flat_labels.count(l)
        if (cnt > 2) or (cnt == 2 and l < 0):
            batch_labels.append(l)
            batch_cnts.append(cnt)

    while len(con_order) > 0:
        # the next index to be contracted
        cont_ind = con_order[0]
        if cont_ind in batch_labels:
            # if its still a batch index then do it later
            con_order.append(con_order.pop(0))
            skip_counter += 1
            # avoid being stuck in an infinite loop
            if skip_counter > len(con_order):
                raise ValueError(
                    f"ncon seems stuck in an infinite loop. \n"
                    f"Please check if `con_order` = {init_con_order} is "
                    f"a valid contraction order for \n"
                    f"`network_structure` = {init_network_structure}")
            continue

        # A contraction is about to happen, i.e. progress is being made:
        # previous skips must not count towards the stuck detection.
        skip_counter = 0

        # find locations of `cont_ind` in `network_structure`
        locs = [
            n for n, labels in enumerate(network_structure)
            if cont_ind in labels
        ]

        # pop the later position first so the earlier index stays valid
        t2 = tensors.pop(locs[1])
        t1 = tensors.pop(locs[0])
        labels_t2 = network_structure.pop(locs[1])
        labels_t1 = network_structure.pop(locs[0])
        common_labels, t1_cont, t2_cont = label_intersection(
            labels_t1, labels_t2)
        # check if there are batch labels (i.e. labels appearing more than twice
        # in `network_structure`).
        common_batch_labels = set(batch_labels).intersection(common_labels)
        if len(common_batch_labels) > 0:
            # case1: both tensors have one or more common batch indices -> use matmul
            ix = np.nonzero(
                np.array(batch_labels)[:, None] == np.array(
                    list(common_batch_labels))[None, :])[0]
            # reduce the counts of these labels in `batch_cnts` by 1 and
            # drop labels that no longer qualify as batch labels
            delete = []
            for i in ix:
                batch_cnts[i] -= 1
                if (batch_labels[i] > 0) and (batch_cnts[i] <= 2):
                    delete.append(i)
                elif (batch_labels[i] < 0) and (batch_cnts[i] < 2):
                    delete.append(i)

            # delete from the back so earlier indices stay valid
            for i in sorted(delete, reverse=True):
                del batch_cnts[i]
                del batch_labels[i]

            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        # in all other cases do a regular tensordot
        else:
            # for len(t1_cont)~<20 this is faster than np.argsort
            ind_sort = [t1_cont.index(l) for l in sorted(t1_cont)]
            tensors.append(
                backend_obj.tensordot(
                    t1,
                    t2,
                    axes=(tuple([t1_cont[i] for i in ind_sort]),
                          tuple([t2_cont[i] for i in ind_sort]))))
            new_labels = [l for l in labels_t1 if l not in common_labels
                          ] + [l for l in labels_t2 if l not in common_labels]
            network_structure.append(new_labels)
            # remove contracted labels from con_order
            con_order = [c for c in con_order if c not in common_labels]

    # perform outer products and remaining batch contractions
    while len(tensors) > 1:
        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()
        # check if there are negative batch indices left
        # (have to be collapsed to a single one)
        common_labels, t1_cont, t2_cont = label_intersection(
            labels_t1, labels_t2)
        common_batch_labels = set(batch_labels).intersection(common_labels)
        if len(common_batch_labels) > 0:
            # collapse all negative batch indices
            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        else:
            tensors.append(backend_obj.outer_product(t1, t2))
            network_structure.append(labels_t1 + labels_t2)

    # if necessary do a final permutation (a single axis needs none)
    if len(network_structure[0]) > 1:
        labels = network_structure[0]
        final_order = tuple([labels.index(l) for l in out_order])
        return backend_obj.transpose(tensors[0], final_order)
    return tensors[0]
# Example #4
0
def _jittable_ncon(tensors: Sequence[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int], backend_obj: AbstractBackend) -> Any:
    """Jittable ncon function.

    This variant does partial traces, binary (tensordot) contractions in the
    order given by `con_order`, outer products of whatever tensors remain,
    and a final transposition so the result's axes follow `out_order`.

    Args:
      tensors: List of tensors. Must be a `list`. It is copied, so the
        caller's list is left untouched.
      flat_labels: A tuple of integers: the labels of all tensors,
        concatenated in the order of `tensors`.
      sizes: Tuple of int used to reconstruct the network structure from
        `flat_labels`; `sizes[n]` is the number of labels of `tensors[n]`.
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If `tensors` is not a list.
    """
    # some jax-juggling to avoid retracing ...
    slices = np.append(0, np.cumsum(sizes))
    network_structure = [
        np.array(flat_labels)[slices[n]:slices[n + 1]]
        for n in range(len(slices) - 1)
    ]
    con_order = np.array(con_order)
    out_order = np.array(out_order)

    # now we're ready to do stuff
    if not isinstance(tensors, list):
        raise ValueError("`tensors` is not a list")
    # Work on a copy: the stages below replace, pop and append entries of
    # `tensors`, which would otherwise consume the caller's list.
    tensors = list(tensors)

    # partial trace: labels traced out inside a single tensor are dropped
    # from `con_order`.
    for n, tensor in enumerate(tensors):
        tensors[n], network_structure[n], contracted_labels = _partial_trace(
            tensor, network_structure[n], backend_obj)
        con_order = np.delete(
            con_order,
            np.intersect1d(con_order, contracted_labels,
                           return_indices=True)[1])

    # binary contractions
    while len(con_order) > 0:
        cont_ind = con_order[0]  # the next index to be contracted
        locs = np.sort(
            np.nonzero([
                np.isin(cont_ind, labels) for labels in network_structure
            ])[0])
        # pop the later position first so the earlier index stays valid
        t2 = tensors.pop(locs[1])
        t1 = tensors.pop(locs[0])
        labels_t2 = network_structure.pop(locs[1])
        labels_t1 = network_structure.pop(locs[0])

        # every label the two tensors share is contracted in one tensordot
        common_labels, t1_cont, t2_cont = np.intersect1d(labels_t1,
                                                         labels_t2,
                                                         assume_unique=True,
                                                         return_indices=True)

        ind_sort = np.argsort(t1_cont)
        tensors.append(
            backend_obj.tensordot(t1,
                                  t2,
                                  axes=(tuple(t1_cont[ind_sort]),
                                        tuple(t2_cont[ind_sort]))))
        network_structure.append(
            np.append(np.delete(labels_t1, t1_cont),
                      np.delete(labels_t2, t2_cont)))

        # remove contracted labels from con_order
        con_order = np.delete(
            con_order,
            np.intersect1d(con_order,
                           common_labels,
                           assume_unique=True,
                           return_indices=True)[1])

    # outer products
    while len(tensors) > 1:
        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()

        tensors.append(backend_obj.outer_product(t1, t2))
        network_structure.append(np.append(labels_t1, labels_t2))

    # final permutation: row i of the comparison matrix is out_order[i],
    # column j is the j-th remaining label; np.nonzero returns the matches
    # in row order.
    if len(network_structure[0]) > 0:
        i1, i2 = np.nonzero(out_order[:,
                                      None] == network_structure[0][None, :])
        return backend_obj.transpose(tensors[0], tuple(i1[i2]))
    return tensors[0]