Пример #1
0
def test_abstract_backend_gmres_not_implemented():
    """Check that `AbstractBackend.gmres` raises `NotImplementedError`."""

    def identity(x):
        return x

    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        rhs = np.ones((2))
        abstract.gmres(identity, rhs)
Пример #2
0
def _jittable_ncon(tensors: Sequence[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int], backend_obj: AbstractBackend) -> Any:
    """Jittable Ncon function.

    Contracts `tensors` pairwise with `backend_obj.tensordot` following
    `con_order`, joins whatever tensors are left with
    `backend_obj.outer_product` and finally transposes the result according
    to `out_order`. `tensors` is modified in place.

    Args:
      tensors: The tensors to contract. Must be a `list`.
      flat_labels: The labels of all tensors, concatenated into one tuple.
      sizes: Number of labels of each tensor; used to split `flat_labels`
        back into one label array per tensor.
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If `tensors` is not a list.
    """
    # Rebuild the per-tensor label arrays from the flat inputs
    # (the original comment attributes the flat form to avoiding jax
    # retracing).
    bounds = np.append(0, np.cumsum(sizes))
    network_structure = [
        np.array(flat_labels)[start:stop]
        for start, stop in zip(bounds[:-1], bounds[1:])
    ]
    con_order = np.array(con_order)
    out_order = np.array(out_order)

    if not isinstance(tensors, list):
        raise ValueError("`tensors` is not a list")

    # Partial traces: labels reported as contracted by `_partial_trace` are
    # dropped from the contraction order.
    for idx, current in enumerate(tensors):
        traced = _partial_trace(current, network_structure[idx], backend_obj)
        tensors[idx], network_structure[idx], contracted_labels = traced
        already_done = np.intersect1d(con_order,
                                      contracted_labels,
                                      return_indices=True)[1]
        con_order = np.delete(con_order, already_done)

    # Binary contractions, driven by the first label still in `con_order`.
    while len(con_order) > 0:
        cont_ind = con_order[0]
        locs = np.flatnonzero(
            [np.isin(cont_ind, labels) for labels in network_structure])
        upper, lower = locs[1], locs[0]
        # Pop the higher position first so the lower one stays valid.
        t2 = tensors.pop(upper)
        t1 = tensors.pop(lower)
        labels_t2 = network_structure.pop(upper)
        labels_t1 = network_structure.pop(lower)

        # All labels shared by the pair are contracted in one tensordot.
        shared, t1_axes, t2_axes = np.intersect1d(labels_t1,
                                                  labels_t2,
                                                  assume_unique=True,
                                                  return_indices=True)
        by_t1_axis = np.argsort(t1_axes)
        contraction_axes = (tuple(t1_axes[by_t1_axis]),
                            tuple(t2_axes[by_t1_axis]))
        tensors.append(backend_obj.tensordot(t1, t2, axes=contraction_axes))

        surviving_t1 = np.delete(labels_t1, t1_axes)
        surviving_t2 = np.delete(labels_t2, t2_axes)
        network_structure.append(np.append(surviving_t1, surviving_t2))

        finished = np.intersect1d(con_order,
                                  shared,
                                  assume_unique=True,
                                  return_indices=True)[1]
        con_order = np.delete(con_order, finished)

    # Outer products of whatever tensors remain.
    while len(tensors) > 1:
        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()

        tensors.append(backend_obj.outer_product(t1, t2))
        network_structure.append(np.append(labels_t1, labels_t2))

    # Final permutation into `out_order` (skipped when no labels are left).
    remaining = network_structure[0]
    if len(remaining) > 0:
        i1, i2 = np.nonzero(out_order[:, None] == remaining[None, :])
        return backend_obj.transpose(tensors[0], tuple(i1[i2]))
    return tensors[0]
Пример #3
0
def test_abstract_backend_zeros_not_implemented():
    """Check that `AbstractBackend.zeros` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    shape = (2, 2)
    with pytest.raises(NotImplementedError):
        abstract.zeros(shape, dtype=np.float64)
Пример #4
0
def test_abstract_backend_eigs_not_implemented():
    """Check that `AbstractBackend.eigs` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        matrix = np.ones((2, 2))
        abstract.eigs(matrix)
Пример #5
0
def test_abstract_backend_convert_to_tensor_not_implemented():
    """Check that `AbstractBackend.convert_to_tensor` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        array = np.ones((2, 2))
        abstract.convert_to_tensor(array)
Пример #6
0
def test_abstract_backend_einsul_not_implemented():
    """Check that `AbstractBackend.einsum` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        operand = np.ones((2, 2))
        abstract.einsum("ii", operand, optimize=True)
Пример #7
0
def test_abstract_backend_broadcast_left_multiplication_not_implemented():
    """Check that `AbstractBackend.broadcast_left_multiplication` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        left = np.ones((2, 2))
        right = np.ones((2, 2))
        abstract.broadcast_left_multiplication(left, right)
Пример #8
0
def test_abstract_backend_shape_concat_not_implemented():
    """Check that `AbstractBackend.shape_concat` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        values = [np.ones((2, 2)), np.ones((2, 2))]
        abstract.shape_concat(values, 0)
Пример #9
0
def test_abstract_backend_divide_not_implemented():
    """Check that `AbstractBackend.divide` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        numerator = np.ones((2, 2))
        denominator = np.ones((2, 2))
        abstract.divide(numerator, denominator)
Пример #10
0
def test_abstract_backend_index_update_not_implemented():
    """Check that `AbstractBackend.index_update` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        first = np.ones((2, 2))
        second = np.ones((2, 2))
        third = np.ones((2, 2))
        abstract.index_update(first, second, third)
Пример #11
0
def test_abstract_backend_multiply_not_implemented():
    """Check that `AbstractBackend.multiply` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        left = np.ones((2, 2))
        right = np.ones((2, 2))
        abstract.multiply(left, right)
Пример #12
0
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
    """
    Jittable Ncon function. Performs the contraction of `tensors`.

    The contraction proceeds in stages: partial traces, summation over
    positive labels that occur only once, pairwise contractions following
    `con_order` (batched via `_batch_cont` when the pair shares a batch
    label, `backend_obj.tensordot` otherwise), then outer products and
    remaining batch contractions, and finally a transpose into `out_order`.

    A label counts as a "batch label" when it occurs more than twice in the
    network, or exactly twice while being negative.

    NOTE: `tensors` is modified in place (entries are overwritten, popped
    and appended).

    Args:
      tensors: List of tensors.
      flat_labels: A Tuple of integers; the labels of all tensors
        concatenated.
      sizes: Tuple of int used to reconstruct `network_structure` from
        `flat_labels` (consecutive chunk lengths).
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If the loop over `con_order` keeps postponing batch
        labels without making progress.
    """
    # some jax-juggling to avoid retracing ...
    # `slices` holds the chunk boundaries of each tensor's labels inside
    # `flat_labels`.
    slices = np.append(0, np.cumsum(sizes))
    network_structure = [
        np.array(flat_labels)[slices[n]:slices[n + 1]]
        for n in range(len(slices) - 1)
    ]
    con_order = np.array(con_order)
    out_order = np.array(out_order)
    # Snapshots of the inputs, only used in the error message below.
    # pylint: disable=unnecessary-comprehension
    init_con_order = [c for c in con_order]
    init_network_structure = [list(c) for c in network_structure]
    # partial trace
    # NOTE(review): `_partial_trace` is defined elsewhere; it is assumed to
    # return (new tensor, new labels, labels it contracted) -- confirm.
    for n, tensor in enumerate(tensors):
        tensors[n], network_structure[n], contracted_labels = _partial_trace(
            tensor, network_structure[n], backend_obj)

        if len(contracted_labels) > 0:
            # labels already traced out must not be contracted again
            con_order = np.delete(
                con_order,
                np.intersect1d(con_order,
                               contracted_labels,
                               return_indices=True,
                               assume_unique=True)[1])

    # contract all positive labels appearing only once in `network_structure`
    unique_labels, label_cnts = np.unique(np.concatenate(network_structure),
                                          return_counts=True)
    contractable_labels = unique_labels[np.logical_and(label_cnts == 1,
                                                       unique_labels > 0)]

    # update con_order
    if len(contractable_labels) > 0:
        con_order = np.delete(
            con_order,
            np.nonzero(np.isin(con_order, contractable_labels))[0])

    # collapse axes of single-labelled tensors
    # (`locs` are the tensors carrying at least one such label)
    locs = [
        n for n, labels in enumerate(network_structure)
        if np.any(np.isin(labels, contractable_labels))
    ]
    for loc in locs:
        labels = network_structure[loc]
        # only the axes of *this* tensor whose label is contractable
        contractable_inds = np.nonzero(np.isin(labels, contractable_labels))[0]
        network_structure[loc] = np.delete(labels, contractable_inds)
        tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

    # perform binary and batch contractions
    # `skip_counter` counts how often a batch label was postponed; it is
    # never reset.
    skip_counter = 0
    unique_labels, label_cnts = np.unique(np.concatenate(network_structure),
                                          return_counts=True)
    # batch labels: more than two occurrences, or two occurrences of a
    # negative label
    mask = np.logical_or(label_cnts > 2,
                         np.logical_and(label_cnts == 2, unique_labels < 0))
    batch_labels = unique_labels[mask]
    batch_cnts = label_cnts[mask]

    while len(con_order) > 0:
        # the next index to be contracted
        cont_ind = con_order[0]
        if cont_ind in batch_labels:
            # if its still a batch index then do it later
            # (rotate it to the end of `con_order`)
            con_order = np.append(np.delete(con_order, 0), cont_ind)
            skip_counter += 1
            # avoid being stuck in an infinite loop
            if skip_counter > len(con_order):
                raise ValueError(
                    f"ncon seems stuck in an infinite loop. \n"
                    f"Please check if `con_order` = {init_con_order} is "
                    f"a valid contraction order for \n"
                    f"`network_structure` = {init_network_structure}")
            continue

        # find locations of `cont_ind` in `network_structure`
        locs = [
            n for n, labels in enumerate(network_structure)
            if sum(labels == cont_ind) > 0
        ]

        # pop the higher position first so the lower one stays valid
        t2 = tensors.pop(locs[1])
        t1 = tensors.pop(locs[0])
        labels_t2 = network_structure.pop(locs[1])
        labels_t1 = network_structure.pop(locs[0])

        # all labels shared by the pair, with their axis positions in each
        common_labels, t1_cont, t2_cont = np.intersect1d(labels_t1,
                                                         labels_t2,
                                                         assume_unique=True,
                                                         return_indices=True)
        # check if there are batch labels (i.e. labels appearing more than twice
        # in `network_structure`).
        common_batch_labels = np.intersect1d(batch_labels,
                                             common_labels,
                                             assume_unique=True)
        if common_batch_labels.shape[0] > 0:
            # case1: both tensors have one or more common batch indices -> use matmul
            # `ix` are the positions of the shared batch labels in
            # `batch_labels`
            ix, _ = np.nonzero(
                batch_labels[:, None] == common_batch_labels[None, :])
            # reduce the counts of these labels in `batch_cnts` by 1
            batch_cnts[ix] -= 1
            # if the count of a positive label falls below 3
            # remove it from `batch_labels`
            # (negative labels are kept as long as their count is at least 2)
            mask = np.logical_or(
                batch_cnts > 2,
                np.logical_and(batch_cnts == 2, batch_labels < 0))
            batch_labels = batch_labels[mask]
            batch_cnts = batch_cnts[mask]

            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        # in all other cases do a regular tensordot
        else:
            # pass the contracted axes ordered by their position in `t1`
            ind_sort = np.argsort(t1_cont)
            tensors.append(
                backend_obj.tensordot(t1,
                                      t2,
                                      axes=(tuple(t1_cont[ind_sort]),
                                            tuple(t2_cont[ind_sort]))))
            # labels of the result: uncontracted labels of `t1`, then of `t2`
            network_structure.append(
                np.append(np.delete(labels_t1, t1_cont),
                          np.delete(labels_t2, t2_cont)))

            # remove contracted labels from con_order
            con_order = np.delete(
                con_order,
                np.intersect1d(con_order,
                               common_labels,
                               assume_unique=True,
                               return_indices=True)[1])

    # perform outer products and remaining batch contractions
    while len(tensors) > 1:
        # batch labels are recomputed from the current network on every pass
        unique_labels, label_cnts = np.unique(
            np.concatenate(network_structure), return_counts=True)
        batch_labels = unique_labels[np.logical_or(
            label_cnts > 2, np.logical_and(label_cnts == 2,
                                           unique_labels < 0))]

        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()
        # check if there are negative batch indices left
        # (have to be collapsed to a single one)
        common_labels, t1_cont, t2_cont = np.intersect1d(labels_t1,
                                                         labels_t2,
                                                         assume_unique=True,
                                                         return_indices=True)
        common_batch_labels = np.intersect1d(batch_labels,
                                             common_labels,
                                             assume_unique=True)
        if common_batch_labels.shape[0] > 0:
            # collapse all negative batch indices
            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        else:
            tensors.append(backend_obj.outer_product(t1, t2))
            network_structure.append(np.append(labels_t1, labels_t2))

    # if necessary do a final permutation
    # (skipped when the result carries no labels)
    if len(network_structure[0]) > 0:
        # pairs (position in `out_order`, position in the result labels)
        # where the labels match
        i1, i2 = np.nonzero(out_order[:,
                                      None] == network_structure[0][None, :])
        return backend_obj.transpose(tensors[0], tuple(i1[i2]))
    return tensors[0]
Пример #13
0
def _batch_cont(
    t1: Tensor, t2: Tensor, tensors: List[Tensor],
    network_structure: List[np.ndarray], con_order: np.ndarray,
    common_batch_labels: np.ndarray, labels_t1: np.ndarray,
    labels_t2: np.ndarray, backend_obj: AbstractBackend
) -> Tuple[Tensor, List[np.ndarray], np.ndarray]:
    """
    Contract `t1` with `t2` as one batched matrix multiplication.

    Both tensors are transposed and reshaped into rank-3 arrays,
    `(batch, free, contracted)` for `t1` and `(batch, contracted, free)`
    for `t2`, multiplied with `backend_obj.matmul`, and the product is
    reshaped to `(batch axes, free axes of t1, free axes of t2)`.
    The result and its labels are appended to `tensors` and
    `network_structure` (both modified in place).

    Args:
      t1: A Tensor.
      t2: A Tensor.
      tensors: List of Tensor objects.
      network_structure: The canonical labels of the networks.
      con_order: Array of contracted labels.
      common_batch_labels: The common batch labels of `t1` and `t2`.
      labels_t1: The labels of `t1`
      labels_t2: The labels of `t2`
      backend_obj: A backend object.

    Returns:
      List[Tensor]: Updated list of tensors.
      List[np.ndarray]: Updated `network_structure`.
      np.ndarray: Updated `con_order` (contraction order).
    """

    def _axes_of(wanted: np.ndarray, labels: np.ndarray) -> np.ndarray:
        # Positions inside `labels` of the entries of `wanted`.
        return np.intersect1d(wanted,
                              labels,
                              assume_unique=True,
                              return_indices=True)[2]

    # Axis positions of the shared batch labels.
    batch_axes_1 = _axes_of(common_batch_labels, labels_t1)
    batch_axes_2 = _axes_of(common_batch_labels, labels_t2)

    # Non-batch labels present on both tensors are summed over.
    others_1 = labels_t1[~np.isin(labels_t1, common_batch_labels)]
    others_2 = labels_t2[~np.isin(labels_t2, common_batch_labels)]
    summed_labels = np.intersect1d(others_1, others_2, assume_unique=True)
    summed_axes_1 = _axes_of(summed_labels, labels_t1)
    summed_axes_2 = _axes_of(summed_labels, labels_t2)

    # Every remaining axis is a free (uncontracted, non-batch) axis.
    free_axes_1 = np.setdiff1d(np.arange(len(labels_t1)),
                               np.append(summed_axes_1, batch_axes_1))
    free_axes_2 = np.setdiff1d(np.arange(len(labels_t2)),
                               np.append(summed_axes_2, batch_axes_2))

    shape_1 = np.array(backend_obj.shape_tuple(t1))
    shape_2 = np.array(backend_obj.shape_tuple(t2))

    # Batch axes go first; contracted axes are placed so that matmul sums
    # over them.
    perm_1 = tuple(np.concatenate([batch_axes_1, free_axes_1, summed_axes_1]))
    perm_2 = tuple(np.concatenate([batch_axes_2, summed_axes_2, free_axes_2]))
    matrix_shape_1 = (np.prod(shape_1[batch_axes_1]),
                      np.prod(shape_1[free_axes_1]),
                      np.prod(shape_1[summed_axes_1]))
    matrix_shape_2 = (np.prod(shape_2[batch_axes_2]),
                      np.prod(shape_2[summed_axes_2]),
                      np.prod(shape_2[free_axes_2]))

    lhs = backend_obj.reshape(backend_obj.transpose(t1, perm_1),
                              matrix_shape_1)
    rhs = backend_obj.reshape(backend_obj.transpose(t2, perm_2),
                              matrix_shape_2)
    product = backend_obj.matmul(lhs, rhs)

    # Unfold the merged axes again.
    unfolded_shape = tuple(
        np.concatenate([
            shape_1[batch_axes_1], shape_1[free_axes_1], shape_2[free_axes_2]
        ]))
    product = backend_obj.reshape(product, unfolded_shape)

    # Register the result and its labels with the network.
    network_structure.append(
        np.concatenate([
            labels_t1[batch_axes_1], labels_t1[free_axes_1],
            labels_t2[free_axes_2]
        ]))
    tensors.append(product)

    # Labels summed over here are finished.
    if len(con_order) > 0:
        finished = np.intersect1d(summed_labels,
                                  con_order,
                                  assume_unique=True,
                                  return_indices=True)[2]
        con_order = np.delete(con_order, finished)

    return tensors, network_structure, con_order
Пример #14
0
def test_abstract_backend_slice_not_implemented():
    """Check that `AbstractBackend.slice` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        array = np.ones((2, 2))
        abstract.slice(array, (0, 1), (1, 1))
def test_switch_backend_raises_error(backend):
    """Check that `tn.switch_backend` raises `NotImplementedError` for a
    node whose backend was replaced by an `AbstractBackend`."""
    node = tn.Node(np.random.rand(3, 3, 3))
    node.backend = AbstractBackend()
    nodes = {node}
    with pytest.raises(NotImplementedError):
        tn.switch_backend(nodes, backend)
Пример #16
0
def test_abstract_backend_rq_decompositon_not_implemented():
    """Check that `AbstractBackend.rq_decomposition` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        matrix = np.ones((2, 2))
        abstract.rq_decomposition(matrix, 0)
Пример #17
0
def _batch_cont(
        t1: Tensor, t2: Tensor, tensors: List[Tensor],
        network_structure: List[List], con_order: List,
        common_batch_labels: Set, labels_t1: List, labels_t2: List,
        backend_obj: AbstractBackend) -> Tuple[Tensor, List[List], List]:
    """
    Contract `t1` with `t2` as one batched matrix multiplication.

    Both tensors are transposed and reshaped into rank-3 arrays,
    `(batch, free, contracted)` for `t1` and `(batch, contracted, free)`
    for `t2`, multiplied with `backend_obj.matmul`, and the product is
    reshaped to `(batch axes, free axes of t1, free axes of t2)`.
    The result and its labels are appended to `tensors` and
    `network_structure` (both modified in place).

    Args:
      t1: A Tensor.
      t2: A Tensor.
      tensors: List of Tensor objects.
      network_structure: The canonical labels of the networks.
      con_order: Array of contracted labels.
      common_batch_labels: The common batch labels of `t1` and `t2`.
      labels_t1: The labels of `t1`
      labels_t2: The labels of `t2`
      backend_obj: A backend object.

    Returns:
      List[Tensor]: Updated list of tensors.
      List[List]: Updated `network_structure`.
      List: Updated `con_order` (contraction order).
    """
    # Fix one iteration order for the batch labels and locate them.
    batch = list(common_batch_labels)
    batch_axes_1 = [labels_t1.index(label) for label in batch]
    batch_axes_2 = [labels_t2.index(label) for label in batch]

    # Non-batch labels present on both tensors are summed over.
    others_1 = {label for label in labels_t1 if label not in batch}
    others_2 = {label for label in labels_t2 if label not in batch}
    summed = list(others_1 & others_2)
    summed_axes_1 = [labels_t1.index(label) for label in summed]
    summed_axes_2 = [labels_t2.index(label) for label in summed]

    # Every axis that is neither batch nor summed is free.
    consumed = set(summed).union(batch)
    free_axes_1 = [
        axis for axis, label in enumerate(labels_t1) if label not in consumed
    ]
    free_axes_2 = [
        axis for axis, label in enumerate(labels_t2) if label not in consumed
    ]

    shape_1 = np.array(backend_obj.shape_tuple(t1))
    shape_2 = np.array(backend_obj.shape_tuple(t2))

    # Batch axes go first; contracted axes are placed so that matmul sums
    # over them.
    perm_1 = tuple(batch_axes_1 + free_axes_1 + summed_axes_1)
    perm_2 = tuple(batch_axes_2 + summed_axes_2 + free_axes_2)
    matrix_shape_1 = (np.prod(shape_1[batch_axes_1]),
                      np.prod(shape_1[free_axes_1]),
                      np.prod(shape_1[summed_axes_1]))
    matrix_shape_2 = (np.prod(shape_2[batch_axes_2]),
                      np.prod(shape_2[summed_axes_2]),
                      np.prod(shape_2[free_axes_2]))

    lhs = backend_obj.reshape(backend_obj.transpose(t1, perm_1),
                              matrix_shape_1)
    rhs = backend_obj.reshape(backend_obj.transpose(t2, perm_2),
                              matrix_shape_2)
    product = backend_obj.matmul(lhs, rhs)

    # Unfold the merged axes again.
    unfolded_shape = tuple(
        np.concatenate([
            shape_1[batch_axes_1], shape_1[free_axes_1], shape_2[free_axes_2]
        ]))
    product = backend_obj.reshape(product, unfolded_shape)

    # Register the result and its labels with the network.
    kept_axes_1 = batch_axes_1 + free_axes_1
    result_labels = [labels_t1[axis] for axis in kept_axes_1]
    result_labels = result_labels + [labels_t2[axis] for axis in free_axes_2]
    network_structure.append(result_labels)
    tensors.append(product)

    # Labels summed over here are finished.
    con_order = [label for label in con_order if label not in summed]

    return tensors, network_structure, con_order
Пример #18
0
def test_abstract_backend_shape_prod_not_implemented():
    """Check that `AbstractBackend.shape_prod` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        values = np.ones((2, 2))
        abstract.shape_prod(values)
Пример #19
0
def _jittable_ncon(tensors: List[Tensor], flat_labels: Tuple[int],
                   sizes: Tuple[int], con_order: Tuple[int],
                   out_order: Tuple[int],
                   backend_obj: AbstractBackend) -> Tensor:
    """
    Jittable Ncon function. Performs the contraction of `tensors`.

    The contraction proceeds in stages: partial traces, summation over
    positive labels that occur only once, pairwise contractions following
    `con_order` (batched via `_batch_cont` when the pair shares a batch
    label, `backend_obj.tensordot` otherwise), then outer products and
    remaining batch contractions, and finally a transpose into `out_order`.

    A label counts as a "batch label" when it occurs more than twice in the
    network, or exactly twice while being negative.

    NOTE: `tensors` is modified in place (entries are overwritten, popped
    and appended).

    Args:
      tensors: List of tensors.
      flat_labels: A Tuple of integers; the labels of all tensors
        concatenated.
      sizes: Tuple of int used to reconstruct `network_structure` from
        `flat_labels` (consecutive chunk lengths).
      con_order: Order of the contraction.
      out_order: Order of the final axis order.
      backend_obj: A backend object.

    Returns:
      The final tensor after contraction.

    Raises:
      ValueError: If the loop over `con_order` keeps postponing batch
        labels without making progress.
    """
    # some jax-juggling to avoid retracing ...
    flat_labels = list(flat_labels)
    slices = np.append(0, np.cumsum(sizes))
    network_structure = [
        flat_labels[slices[n]:slices[n + 1]] for n in range(len(slices) - 1)
    ]
    out_order = list(out_order)
    con_order = list(con_order)
    # Snapshots of the inputs for the error message below. The inner label
    # lists are copied too, so the snapshot cannot be affected by later
    # changes to the per-tensor label lists.
    init_con_order = list(con_order)
    init_network_structure = [list(c) for c in network_structure]

    # partial trace
    # NOTE(review): `_partial_trace` is defined elsewhere; it is assumed to
    # return (new tensor, new labels, labels it contracted) -- confirm.
    for n, tensor in enumerate(tensors):
        tensors[n], network_structure[n], contracted_labels = _partial_trace(
            tensor, network_structure[n], backend_obj)
        if len(contracted_labels) > 0:
            con_order = [c for c in con_order if c not in contracted_labels]

    flat_labels = [l for sublist in network_structure for l in sublist]
    # contract all positive labels appearing only once in `network_structure`
    contractable_labels = [
        l for l in flat_labels if (flat_labels.count(l) == 1) and (l > 0)
    ]
    # update con_order
    if len(contractable_labels) > 0:
        con_order = [o for o in con_order if o not in contractable_labels]
    # collapse axes of single-labelled tensors
    locs = []
    for n, labels in enumerate(network_structure):
        if len(set(labels).intersection(contractable_labels)) > 0:
            locs.append(n)

    for loc in locs:
        labels = network_structure[loc]
        # Only the axes of *this* tensor are summed: `contractable_labels`
        # may also hold labels that live on other tensors, and looking those
        # up with `labels.index` would raise a ValueError.
        contractable_inds = [
            n for n, l in enumerate(labels) if l in contractable_labels
        ]
        network_structure[loc] = [
            l for l in labels if l not in contractable_labels
        ]
        tensors[loc] = backend_obj.sum(tensors[loc], tuple(contractable_inds))

    # perform binary and batch contractions
    # `skip_counter` counts how often a batch label was postponed; it is
    # never reset.
    skip_counter = 0
    batch_labels = []
    batch_cnts = []
    # `flat_labels` still contains the single-occurrence labels collapsed
    # above; with a count of 1 they can never qualify as batch labels.
    for l in set(flat_labels):
        cnt = flat_labels.count(l)
        if (cnt > 2) or (cnt == 2 and l < 0):
            batch_labels.append(l)
            batch_cnts.append(cnt)

    while len(con_order) > 0:
        # the next index to be contracted
        cont_ind = con_order[0]
        if cont_ind in batch_labels:
            # if its still a batch index then do it later
            # (rotate it to the end of `con_order`)
            con_order.append(con_order.pop(0))
            skip_counter += 1
            # avoid being stuck in an infinite loop
            if skip_counter > len(con_order):
                raise ValueError(
                    f"ncon seems stuck in an infinite loop. \n"
                    f"Please check if `con_order` = {init_con_order} is "
                    f"a valid contraction order for \n"
                    f"`network_structure` = {init_network_structure}")
            continue

        # find locations of `cont_ind` in `network_structure`
        locs = [
            n for n, labels in enumerate(network_structure)
            if cont_ind in labels
        ]

        # pop the higher position first so the lower one stays valid
        t2 = tensors.pop(locs[1])
        t1 = tensors.pop(locs[0])
        labels_t2 = network_structure.pop(locs[1])
        labels_t1 = network_structure.pop(locs[0])
        common_labels, t1_cont, t2_cont = label_intersection(
            labels_t1, labels_t2)
        # check if there are batch labels (i.e. labels appearing more than twice
        # in `network_structure`).
        common_batch_labels = set(batch_labels).intersection(common_labels)
        if len(common_batch_labels) > 0:
            # case1: both tensors have one or more common batch indices -> use matmul
            # `ix` are the positions of the shared batch labels in
            # `batch_labels`
            ix = np.nonzero(
                np.array(batch_labels)[:, None] == np.array(
                    list(common_batch_labels))[None, :])[0]
            # reduce the counts of these labels in `batch_cnts` by 1;
            # a positive label stops being a batch label once its count
            # drops to 2, a negative one once it drops below 2
            delete = []
            for i in ix:
                batch_cnts[i] -= 1
                if (batch_labels[i] > 0) and (batch_cnts[i] <= 2):
                    delete.append(i)
                elif (batch_labels[i] < 0) and (batch_cnts[i] < 2):
                    delete.append(i)

            # delete from the back so earlier positions stay valid
            for i in sorted(delete, reverse=True):
                del batch_cnts[i]
                del batch_labels[i]

            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        # in all other cases do a regular tensordot
        else:
            # pass the contracted axes ordered by their position in `t1`
            # for len(t1_cont)~<20 this is faster than np.argsort
            ind_sort = [t1_cont.index(l) for l in sorted(t1_cont)]
            tensors.append(
                backend_obj.tensordot(
                    t1,
                    t2,
                    axes=(tuple([t1_cont[i] for i in ind_sort]),
                          tuple([t2_cont[i] for i in ind_sort]))))
            # labels of the result: uncontracted labels of `t1`, then of `t2`
            new_labels = [l for l in labels_t1 if l not in common_labels
                          ] + [l for l in labels_t2 if l not in common_labels]
            network_structure.append(new_labels)
            # remove contracted labels from con_order
            con_order = [c for c in con_order if c not in common_labels]

    # perform outer products and remaining batch contractions
    while len(tensors) > 1:
        t2 = tensors.pop()
        t1 = tensors.pop()
        labels_t2 = network_structure.pop()
        labels_t1 = network_structure.pop()
        # check if there are negative batch indices left
        # (have to be collapsed to a single one)
        common_labels, t1_cont, t2_cont = label_intersection(
            labels_t1, labels_t2)
        common_batch_labels = set(batch_labels).intersection(common_labels)
        if len(common_batch_labels) > 0:
            # collapse all negative batch indices
            tensors, network_structure, con_order = _batch_cont(
                t1, t2, tensors, network_structure, con_order,
                common_batch_labels, labels_t1, labels_t2, backend_obj)
        else:
            tensors.append(backend_obj.outer_product(t1, t2))
            network_structure.append(labels_t1 + labels_t2)

    # if necessary do a final permutation
    # (with fewer than two labels there is nothing to reorder)
    if len(network_structure[0]) > 1:
        labels = network_structure[0]
        final_order = tuple([labels.index(l) for l in out_order])
        return backend_obj.transpose(tensors[0], final_order)
    return tensors[0]
Пример #20
0
def test_abstract_backend_outer_product_not_implemented():
    """Check that `AbstractBackend.outer_product` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        left = np.ones((2, 2))
        right = np.ones((2, 2))
        abstract.outer_product(left, right)
Пример #21
0
def test_abstract_backend_name():
    """Check that `AbstractBackend().name` equals "abstract backend"."""
    reported_name = AbstractBackend().name
    assert reported_name == "abstract backend"
Пример #22
0
def test_abstract_backend_eye_not_implemented():
    """Check that `AbstractBackend.eye` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    size = 2
    with pytest.raises(NotImplementedError):
        abstract.eye(size, dtype=np.float64)
Пример #23
0
def test_abstract_backend_tensordot_not_implemented():
    """Check that `AbstractBackend.tensordot` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        left = np.ones((2, 2))
        right = np.ones((2, 2))
        abstract.tensordot(left, right, axes=[[0], [0]])
Пример #24
0
def test_abstract_backend_random_uniforl_not_implemented():
    """Check that `AbstractBackend.random_uniform` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    shape = (2, 2)
    with pytest.raises(NotImplementedError):
        abstract.random_uniform(shape)
Пример #25
0
def test_abstract_backend_reshape_not_implemented():
    """Check that `AbstractBackend.reshape` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        array = np.ones((2, 2))
        abstract.reshape(array, (4, 1))
Пример #26
0
def test_abstract_backend_eigs_lanczos_not_implemented():
    """Check that `AbstractBackend.eigsh_lanczos` raises
    `NotImplementedError`."""

    def identity(x):
        return x

    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        start = np.ones((2))
        abstract.eigsh_lanczos(identity, start)
Пример #27
0
def test_abstract_backend_transpose_not_implemented():
    """Check that `AbstractBackend.transpose` raises `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        array = np.ones((2, 2))
        abstract.transpose(array, [0, 1])
Пример #28
0
def test_abstract_backend_subtraction_not_implemented():
    """Check that `AbstractBackend.subtraction` raises
    `NotImplementedError`."""
    abstract = AbstractBackend()
    with pytest.raises(NotImplementedError):
        left = np.ones((2, 2))
        right = np.ones((2, 2))
        abstract.subtraction(left, right)
Пример #29
0
def test_pivot_not_implemented():
  """Check that `AbstractBackend.pivot` raises `NotImplementedError`."""
  abstract = AbstractBackend()
  with pytest.raises(NotImplementedError):
    array = np.ones((2, 2))
    abstract.pivot(array)