Code Example #1
# Imports assumed for a self-contained snippet (based on the TensorNetwork
# package layout); `Tensor` is treated here as an alias for `Any`.
from typing import Any, List, Set, Tuple

import numpy as np

from tensornetwork.backends.abstract_backend import AbstractBackend

Tensor = Any


def _batch_cont(
        t1: Tensor, t2: Tensor, tensors: List[Tensor],
        network_structure: List[List], con_order: List,
        common_batch_labels: Set, labels_t1: List, labels_t2: List,
        backend_obj: AbstractBackend
) -> Tuple[List[Tensor], List[List], List]:
    """
  Subroutine for performing a batched contraction of tensors `t1` and `t2`.
  Args:
    t1: A Tensor.
    t2: A Tensor.
    tensors: List of Tensor objects.
    network_structure: The canonical labels of the networks.
    con_order: Array of contracted labels.
    common_batch_labels: The common batch labels of `t1` and `t2`.
    labels_t1: The labels of `t1`
    labels_t2: The labels of `t2`
    backend_obj: A backend object.
  Returns:
    List[Tensor]: Updated list of tensors.
    List[List]: Updated `network_structure`.
    List: Updated `con_order` (contraction order).
  """
    common_batch_labels = list(common_batch_labels)
    # find positions of common batch labels
    t1_batch_pos = [labels_t1.index(l) for l in common_batch_labels]
    t2_batch_pos = [labels_t2.index(l) for l in common_batch_labels]
    # find positions of contracted non-batch labels
    non_batch_labels_t1 = {
        l
        for l in labels_t1 if l not in common_batch_labels
    }
    non_batch_labels_t2 = {
        l
        for l in labels_t2 if l not in common_batch_labels
    }

    common_contracted_labels = list(
        non_batch_labels_t1.intersection(non_batch_labels_t2))
    t1_cont = [labels_t1.index(l) for l in common_contracted_labels]
    t2_cont = [labels_t2.index(l) for l in common_contracted_labels]

    free_labels_t1 = set(labels_t1) - set(common_contracted_labels) - set(
        common_batch_labels)
    free_labels_t2 = set(labels_t2) - set(common_contracted_labels) - set(
        common_batch_labels)

    # find positions of uncontracted non-batch labels
    free_pos_t1 = [n for n, l in enumerate(labels_t1) if l in free_labels_t1]
    free_pos_t2 = [n for n, l in enumerate(labels_t2) if l in free_labels_t2]

    t1_shape = np.array(backend_obj.shape_tuple(t1))
    t2_shape = np.array(backend_obj.shape_tuple(t2))

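    # target shapes for the rank-3 matricization:
    # t1 -> (batch, free, contracted), t2 -> (batch, contracted, free)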
    newshape_t1 = (np.prod(t1_shape[t1_batch_pos]),
                   np.prod(t1_shape[free_pos_t1]), np.prod(t1_shape[t1_cont]))
    newshape_t2 = (np.prod(t2_shape[t2_batch_pos]), np.prod(t2_shape[t2_cont]),
                   np.prod(t2_shape[free_pos_t2]))

    # bring batch labels to the front
    order_t1 = tuple(t1_batch_pos + free_pos_t1 + t1_cont)
    order_t2 = tuple(t2_batch_pos + t2_cont + free_pos_t2)

    mat1 = backend_obj.reshape(backend_obj.transpose(t1, order_t1),
                               newshape_t1)
    mat2 = backend_obj.reshape(backend_obj.transpose(t2, order_t2),
                               newshape_t2)
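    # batched matmul over the shared batch dimension:
    # (B, f1, c) @ (B, c, f2) -> (B, f1, f2)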
    result = backend_obj.matmul(mat1, mat2)
    final_shape = tuple(
        np.concatenate([
            t1_shape[t1_batch_pos], t1_shape[free_pos_t1],
            t2_shape[free_pos_t2]
        ]))
    result = backend_obj.reshape(result, final_shape)

    # update labels, tensors, network_structure and con_order
    new_labels = [labels_t1[i] for i in t1_batch_pos] + [
        labels_t1[i] for i in free_pos_t1
    ] + [labels_t2[i] for i in free_pos_t2]

    network_structure.append(new_labels)
    tensors.append(result)
    con_order = [c for c in con_order if c not in common_contracted_labels]

    return tensors, network_structure, con_order
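
The reshape-and-matmul pattern used above can be reproduced with a few lines of plain NumPy: permute each operand so its batch labels come first, flatten it into a rank-3 (batch, free, contracted) / (batch, contracted, free) array, contract the pair with a single batched matmul, and expand the result back to its individual dimensions. The sketch below is illustrative only (tensor shapes and axis labels are made up) and checks the pattern against an equivalent np.einsum call.

import numpy as np

# toy tensors with hand-picked axis labels:
#   t1 axes: (k, b, x, y)  ->  b is the batch label, k is contracted, x/y are free
#   t2 axes: (j, b, k)     ->  b is the batch label, k is contracted, j is free
b, k, x, y, j = 2, 3, 4, 5, 6
t1 = np.random.rand(k, b, x, y)
t2 = np.random.rand(j, b, k)

# bring batch labels to the front, then free labels, then contracted labels
# (the analogue of `order_t1` / `order_t2` above), and flatten to rank 3
mat1 = t1.transpose(1, 2, 3, 0).reshape(b, x * y, k)  # (batch, free, contracted)
mat2 = t2.transpose(1, 2, 0).reshape(b, k, j)         # (batch, contracted, free)

# one batched matmul over the shared batch dimension, expanded back afterwards
result = np.matmul(mat1, mat2).reshape(b, x, y, j)

# the same contraction written as an explicit einsum over the same labels
expected = np.einsum('kbxy,jbk->bxyj', t1, t2)
assert np.allclose(result, expected)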
Code Example #2
# Imports assumed for a self-contained snippet (based on the TensorNetwork
# package layout); `Tensor` is treated here as an alias for `Any`.
from typing import Any, List, Tuple

import numpy as np

from tensornetwork.backends.abstract_backend import AbstractBackend

Tensor = Any


def _batch_cont(
    t1: Tensor, t2: Tensor, tensors: List[Tensor],
    network_structure: List[np.ndarray], con_order: np.ndarray,
    common_batch_labels: np.ndarray, labels_t1: np.ndarray,
    labels_t2: np.ndarray, backend_obj: AbstractBackend
) -> Tuple[List[Tensor], List[np.ndarray], np.ndarray]:
    """
  Subroutine for performing a batched contraction of tensors `t1` and `t2`.
  Args:
    t1: A Tensor.
    t2: A Tensor.
    tensors: List of Tensor objects.
    network_structure: The canonical labels of the networks.
    con_order: Array of contracted labels.
    common_batch_labels: The common batch labels of `t1` and `t2`.
    labels_t1: The labels of `t1`
    labels_t2: The labels of `t2`
    backend_obj: A backend object.
  Returns:
    List[Tensor]: Updated list of tensors.
    List[np.ndarray]: Updated `network_structure`.
    np.ndarray: Updated `con_order` (contraction order).
  """
    # find positions of common batch labels
    _, _, t1_batch_pos = np.intersect1d(common_batch_labels,
                                        labels_t1,
                                        assume_unique=True,
                                        return_indices=True)
    _, _, t2_batch_pos = np.intersect1d(common_batch_labels,
                                        labels_t2,
                                        assume_unique=True,
                                        return_indices=True)

    # find positions of contracted non-batch labels
    non_batch_labels_t1 = labels_t1[np.logical_not(
        np.isin(labels_t1, common_batch_labels))]
    non_batch_labels_t2 = labels_t2[np.logical_not(
        np.isin(labels_t2, common_batch_labels))]
    common_contracted_labels = np.intersect1d(non_batch_labels_t1,
                                              non_batch_labels_t2,
                                              assume_unique=True)
    _, _, t1_cont = np.intersect1d(common_contracted_labels,
                                   labels_t1,
                                   assume_unique=True,
                                   return_indices=True)
    _, _, t2_cont = np.intersect1d(common_contracted_labels,
                                   labels_t2,
                                   assume_unique=True,
                                   return_indices=True)

    # find positions of uncontracted non-batch labels
    free_pos_t1 = np.setdiff1d(np.arange(len(labels_t1)),
                               np.append(t1_cont, t1_batch_pos))
    free_pos_t2 = np.setdiff1d(np.arange(len(labels_t2)),
                               np.append(t2_cont, t2_batch_pos))

    t1_shape = np.array(backend_obj.shape_tuple(t1))
    t2_shape = np.array(backend_obj.shape_tuple(t2))

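    # target shapes for the rank-3 matricization:
    # t1 -> (batch, free, contracted), t2 -> (batch, contracted, free)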
    newshape_t1 = (np.prod(t1_shape[t1_batch_pos]),
                   np.prod(t1_shape[free_pos_t1]), np.prod(t1_shape[t1_cont]))
    newshape_t2 = (np.prod(t2_shape[t2_batch_pos]), np.prod(t2_shape[t2_cont]),
                   np.prod(t2_shape[free_pos_t2]))

    # bring batch labels to the front
    order_t1 = tuple(np.concatenate([t1_batch_pos, free_pos_t1, t1_cont]))
    order_t2 = tuple(np.concatenate([t2_batch_pos, t2_cont, free_pos_t2]))

    mat1 = backend_obj.reshape(backend_obj.transpose(t1, order_t1),
                               newshape_t1)
    mat2 = backend_obj.reshape(backend_obj.transpose(t2, order_t2),
                               newshape_t2)
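    # batched matmul over the shared batch dimension:
    # (B, f1, c) @ (B, c, f2) -> (B, f1, f2)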
    result = backend_obj.matmul(mat1, mat2)
    final_shape = tuple(
        np.concatenate([
            t1_shape[t1_batch_pos], t1_shape[free_pos_t1],
            t2_shape[free_pos_t2]
        ]))
    result = backend_obj.reshape(result, final_shape)

    # update labels, tensors, network_structure and con_order
    new_labels = np.concatenate([
        labels_t1[t1_batch_pos], labels_t1[free_pos_t1], labels_t2[free_pos_t2]
    ])
    network_structure.append(new_labels)
    tensors.append(result)
    if len(con_order) > 0:
        con_order = np.delete(
            con_order,
            np.intersect1d(common_contracted_labels,
                           con_order,
                           assume_unique=True,
                           return_indices=True)[2])

    return tensors, network_structure, con_order
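
As a quick sanity check of this second version, the sketch below drives `_batch_cont` with a minimal NumPy stand-in for the backend interface it uses (`shape_tuple`, `transpose`, `reshape`, `matmul`). The stub backend, label values, and tensor shapes are all made up for illustration; in the library an actual `AbstractBackend` implementation would be passed instead.

import numpy as np

# A minimal NumPy stand-in for the backend interface the routine relies on
# (shape_tuple / transpose / reshape / matmul). Purely illustrative.
class NumpyStubBackend:
    def shape_tuple(self, tensor):
        return tensor.shape

    def transpose(self, tensor, perm):
        return np.transpose(tensor, perm)

    def reshape(self, tensor, shape):
        return np.reshape(tensor, shape)

    def matmul(self, tensor1, tensor2):
        return np.matmul(tensor1, tensor2)


# label 0 is the shared batch label, 2 is contracted, 1 and 3 remain free
t1 = np.random.rand(2, 3, 4)  # axes labeled (0, 1, 2)
t2 = np.random.rand(2, 4, 5)  # axes labeled (0, 2, 3)

tensors, network_structure, con_order = _batch_cont(
    t1, t2,
    tensors=[],
    network_structure=[],
    con_order=np.array([2]),
    common_batch_labels=np.array([0]),
    labels_t1=np.array([0, 1, 2]),
    labels_t2=np.array([0, 2, 3]),
    backend_obj=NumpyStubBackend())

# the appended tensor matches the corresponding einsum, the result carries
# labels (0, 1, 3), and label 2 has been removed from con_order
expected = np.einsum('abc,acd->abd', t1, t2)
assert np.allclose(tensors[-1], expected)
print(network_structure[-1])  # [0 1 3]
print(con_order)              # [] (label 2 has been contracted away)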