Example #1
0
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, norm=None, face=None, **kwargs):
        r"""Initializes the graph-data holder.

        Stores the standard graph attributes, then any extra keyword
        arguments (``num_nodes`` is kept in ``__num_nodes__``; everything
        else goes through item assignment).  Index-like arguments must be
        of dtype ``torch.long``.
        """
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        self.norm = norm
        self.face = face
        for attr_name, attr_value in kwargs.items():
            if attr_name == 'num_nodes':
                self.__num_nodes__ = attr_value
            else:
                self[attr_name] = attr_value

        # Both index tensors share the same dtype requirement; the error
        # message template reproduces the original wording exactly.
        for name, index_tensor in (('edge_index', edge_index), ('face', face)):
            if index_tensor is not None and index_tensor.dtype != torch.long:
                raise ValueError(
                    (f'Argument `{name}` needs to be of type `torch.long` but '
                     f'found type `{index_tensor.dtype}`.'))

        if torch_geometric.is_debug_enabled():
            self.debug()
Example #2
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""
        # `None` sentinel instead of a shared mutable default argument;
        # `follow_batch` is accepted for interface compatibility but is not
        # used in this implementation.
        follow_batch = [] if follow_batch is None else follow_batch

        for data in data_list:
            assert isinstance(data, MultiScaleData)
        # Every entry must carry the same number of scales.
        num_scales = data_list[0].num_scales
        for data_entry in data_list:
            assert data_entry.num_scales == num_scales, "All data objects should contain the same number of scales"

        # Build one batched object per scale out of the per-entry data.
        multiscale = []
        for scale in range(num_scales):
            ms_scale = []
            for data_entry in data_list:
                ms_scale.append(data_entry.multiscale[scale])
            multiscale.append(from_data_list_token(ms_scale))

        # Create batch from non multiscale data.
        # NOTE(review): this mutates the input objects — their `multiscale`
        # attribute is deleted before delegating to `Batch.from_data_list`.
        for data_entry in data_list:
            del data_entry.multiscale
        batch = Batch.from_data_list(data_list)
        batch = MultiScaleBatch.from_data(batch)
        batch.multiscale = multiscale

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch
Example #3
0
def from_data_list_token(data_list, follow_batch=None):
    """Batch a list of data objects, leaving negative indices unshifted.

    This is essentially a copy of PyTorch Geometric's
    ``Batch.from_data_list`` with one difference: tensor entries that are
    negative are *not* incremented by the cumulative per-key offset.

    Args:
        data_list: list of data objects sharing compatible keys.
        follow_batch (list, optional): keys for which an additional
            assignment vector ``<key>_batch`` is created.

    Returns:
        A contiguous :class:`Batch` with concatenated attributes.
    """
    # `None` sentinel instead of a shared mutable default argument.
    follow_batch = [] if follow_batch is None else follow_batch

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    num_nodes = None  # stays None when `data_list` is empty
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            # Shift only non-negative entries by the running offset.
            # NOTE(review): this writes into the caller's tensors in place.
            if torch.is_tensor(item) and item.dtype != torch.bool:
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    if num_nodes is None:
        batch.batch = None

    # Concatenate the collected per-key lists into single tensors.
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
    def __produce_bipartite_data_flow__(self, n_id):
        r"""Produces a :obj:`DataFlow` object with a bipartite assignment
        matrix for a given mini-batch :obj:`n_id`."""
        data_flow = DataFlow(n_id, self.flow)

        # Expand the receptive field hop by hop, sampling `self.size[l]`
        # neighbors per layer.
        for l in range(self.num_hops):
            e_id = neighbor_sampler(n_id, self.cumdeg, self.size[l])

            # Source nodes of the sampled edges; then map sampled positions
            # back to global edge ids.
            new_n_id = self.edge_index_j.index_select(0, e_id)
            e_id = self.e_assoc[e_id]

            if self.add_self_loops:
                # Append the current nodes so each keeps an edge to itself;
                # `res_n_id` locates them inside the unique()'d result.
                new_n_id = torch.cat([new_n_id, n_id], dim=0)
                new_n_id, inv = new_n_id.unique(sorted=False,
                                                return_inverse=True)
                res_n_id = inv[-n_id.size(0):]
            else:
                new_n_id = new_n_id.unique(sorted=False)
                res_n_id = None

            # Build the bipartite edge index in local (renumbered) ids;
            # `self.tmp` serves as a global-id -> local-id scatter table.
            edges = [None, None]
            edge_index_i = self.edge_index[self.i, e_id]
            if self.add_self_loops:
                edge_index_i = torch.cat([edge_index_i, n_id], dim=0)

            self.tmp[n_id] = torch.arange(n_id.size(0))
            edges[self.i] = self.tmp[edge_index_i]

            edge_index_j = self.edge_index[self.j, e_id]
            if self.add_self_loops:
                edge_index_j = torch.cat([edge_index_j, n_id], dim=0)

            self.tmp[new_n_id] = torch.arange(new_n_id.size(0))
            edges[self.j] = self.tmp[edge_index_j]

            edge_index = torch.stack(edges, dim=0)

            e_id = self.e_id[e_id]
            if self.add_self_loops:
                if self.edge_index_loop.size(1) == self.data.num_nodes:
                    # Only set `e_id` if all self-loops were initially passed
                    # to the graph.
                    e_id = torch.cat([e_id, self.e_id_loop[n_id]])
                else:
                    e_id = None
                    if torch_geometric.is_debug_enabled():
                        warnings.warn(
                            ('Could not add edge identifiers to the DataFlow'
                             'object due to missing initial self-loops. '
                             'Please make sure that your graph already '
                             'contains self-loops in case you want to use '
                             'edge-conditioned operators.'))

            # The sampled frontier becomes the next hop's seed set.
            n_id = new_n_id

            data_flow.append(n_id, res_n_id, e_id, edge_index)

        return data_flow
Example #5
0
    def from_dict(cls, dictionary):
        r"""Creates a data object from a python dictionary."""
        data_object = cls()

        # Copy every entry over via item assignment.
        for attr_name, attr_value in dictionary.items():
            data_object[attr_name] = attr_value

        if torch_geometric.is_debug_enabled():
            data_object.debug()

        return data_object
Example #6
0
    def forward(self, x, edge_index, pseudo, edge_weight=None):
        """Runs the forward pass: optional debug validation, then maps `x`
        through ``self.ball_in.logmap0`` and delegates to ``propagate``.

        Args:
            x: node feature tensor; 1-D input is treated as single-channel.
            pseudo: pseudo-coordinates; expected within ``[0, 1]``.
            edge_weight: optional weights forwarded to ``propagate``.
        """
        # Lazy import so the debug checks degrade gracefully when
        # torch_geometric is unavailable.
        try:
            import torch_geometric
        except ImportError:  # for debug mode only
            torch_geometric = None
        # Treat 1-D inputs as (N, 1) single-channel tensors.
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

        if not x.is_cuda:
            warnings.warn(
                "We do not recommend using the non-optimized CPU "
                "version of SplineConv. If possible, please convert "
                "your data to the GPU."
            )

        if torch_geometric is not None and torch_geometric.is_debug_enabled():
            if x.size(1) != self.in_channels:
                raise RuntimeError(
                    "Expected {} node features, but found {}".format(
                        self.in_channels, x.size(1)
                    )
                )

            if pseudo.size(1) != self.dim:
                raise RuntimeError(
                    (
                        "Expected pseudo-coordinate dimensionality of {}, but "
                        "found {}"
                    ).format(self.dim, pseudo.size(1))
                )

            # Edge indices must reference valid rows of `x`.
            min_index, max_index = edge_index.min(), edge_index.max()
            if min_index < 0 or max_index > x.size(0) - 1:
                raise RuntimeError(
                    (
                        "Edge indices must lay in the interval [0, {}]"
                        " but found them in the interval [{}, {}]"
                    ).format(x.size(0) - 1, min_index, max_index)
                )

            # Pseudo-coordinates must be normalized to [0, 1].
            min_pseudo, max_pseudo = pseudo.min(), pseudo.max()
            if min_pseudo < 0 or max_pseudo > 1:
                raise RuntimeError(
                    (
                        "Pseudo-coordinates must lay in the fixed interval [0, 1]"
                        " but found them in the interval [{}, {}]"
                    ).format(min_pseudo, max_pseudo)
                )
        # Map features via logmap0 — presumably from a Poincaré ball to the
        # tangent space at the origin; confirm against `ball_in`'s class.
        x = self.ball_in.logmap0(x)
        return self.propagate(edge_index, x=x, pseudo=pseudo, edge_weight=edge_weight)
Example #7
0
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, norm=None, face=None, **kwargs):
        r"""Initializes the data holder with the standard graph attributes.

        Extra keyword arguments are stored via item assignment, except for
        ``num_nodes`` which is kept in ``__num_nodes__``.
        """
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        self.norm = norm
        self.face = face
        for attr_name, attr_value in kwargs.items():
            if attr_name == 'num_nodes':
                self.__num_nodes__ = attr_value
            else:
                self[attr_name] = attr_value

        if torch_geometric.is_debug_enabled():
            self.debug()
    def __init__(self, img=None, y=None, indices=None, mask=None, **kwargs):
        r"""Initializes an image-based data holder.

        ``num_nodes`` is derived from the first dimension of ``indices``
        when given; extra keyword arguments are stored via item assignment
        (``num_nodes`` in kwargs overrides via ``__num_nodes__``).
        """
        self.img = img
        self.y = y
        self.indices = indices
        self.mask = mask
        # Node count follows directly from the index tensor, if provided.
        self.num_nodes = None if indices is None else self.indices.shape[0]
        for attr_name, attr_value in kwargs.items():
            if attr_name == 'num_nodes':
                self.__num_nodes__ = attr_value
            else:
                self[attr_name] = attr_value

        if torch_geometric.is_debug_enabled():
            self.debug()
Example #9
0
    def forward(self, x, edge_index, pseudo):
        """Validates inputs in debug mode, then runs message passing via
        :meth:`propagate`."""
        # Treat 1-D inputs as (N, 1) single-channel tensors.
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

        if not x.is_cuda:
            warnings.warn('We do not recommend using the non-optimized CPU '
                          'version of SplineConv. If possible, please convert '
                          'your data to the GPU.')

        if torch_geometric.is_debug_enabled():
            if x.size(1) != self.in_channels:
                raise RuntimeError(
                    'Expected {} node features, but found {}'.format(
                        self.in_channels, x.size(1)))

            if pseudo.size(1) != self.dim:
                raise RuntimeError(
                    ('Expected pseudo-coordinate dimensionality of {}, but '
                     'found {}').format(self.dim, pseudo.size(1)))

            # Edge indices must reference valid rows of `x`.
            min_index, max_index = edge_index.min(), edge_index.max()
            if min_index < 0 or max_index > x.size(0) - 1:
                raise RuntimeError(
                    ('Edge indices must lay in the interval [0, {}]'
                     ' but found them in the interval [{}, {}]').format(
                         x.size(0) - 1, min_index, max_index))

            # Pseudo-coordinates must be normalized to [0, 1].
            min_pseudo, max_pseudo = pseudo.min(), pseudo.max()
            if min_pseudo < 0 or max_pseudo > 1:
                raise RuntimeError(
                    ('Pseudo-coordinates must lay in the fixed interval [0, 1]'
                     ' but found them in the interval [{}, {}]').format(
                         min_pseudo, max_pseudo))

        return self.propagate(edge_index, x=x, pseudo=pseudo)
Example #10
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""
        # `None` sentinel instead of a shared mutable default argument.
        follow_batch = [] if follow_batch is None else follow_batch

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch()

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        cumsum = {}
        batch.batch = []
        num_nodes = None  # remains None if `data_list` yields no entries
        for i, data in enumerate(data_list):
            for key in data.keys:
                # Shift values by the running per-key offset, then grow the
                # offset by this entry's increment.
                item = data[key] + cumsum.get(key, 0)
                cumsum[key] = cumsum.get(key, 0) + data.__inc__(key, item)
                batch[key].append(item)

            for key in follow_batch:
                # One assignment-vector entry per element along the concat
                # dimension of the followed key.
                size = data[key].size(data.__cat_dim__(key, data[key]))
                item = torch.full((size, ), i, dtype=torch.long)
                batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        if num_nodes is None:
            batch.batch = None

        # Concatenate the per-key lists into single tensors.
        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])
            else:
                raise ValueError('Unsupported attribute type.')

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #11
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""

        # Union of all attribute keys across the data objects.
        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = MultiScaleBatch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        # `cumsum` holds per-key index offsets; `cumsum4list` holds one
        # offset per scale for list-valued keys (neigh/pool/upsample).
        cumsum = {key: 0 for key in keys}
        cumsum4list = {key: [] for key in keys}
        batch.batch = []
        batch.list_batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, data[key]))

                # Here particular case for this kind of list of tensors
                # process the neighbors
                # NOTE(review): when this branch is taken, `size` keeps the
                # value from the previous key (and is unbound on the very
                # first key if the item is not a tensor) — confirm that the
                # `__slices__` entry written below is intended for lists.
                if bool(re.search('(neigh|pool|upsample)', key)):
                    if isinstance(item, list):
                        if (len(cumsum4list[key]) == 0):  # for the first time
                            cumsum4list[key] = torch.zeros(len(item),
                                                           dtype=torch.long)
                        for j in range(len(item)):
                            if (torch.is_tensor(item[j])
                                    and item[j].dtype != torch.bool):
                                # Shift only strictly-positive entries, in
                                # place, then grow the per-scale offset.
                                item[j][item[j] > 0] += cumsum4list[key][j]
                                # print(key, data["{}_size".format(key)][j])
                                cumsum4list[key][j] += data["{}_size".format(
                                    key)][j]

                else:
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                cumsum[key] += data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size, ), i, dtype=torch.long)
                    batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

            # indice of the batch at each scale
            if (data.points is not None):
                list_batch = []
                for j in range(len(data['points'])):
                    size = len(data.points[j])
                    item = torch.full((size, ), i, dtype=torch.long)
                    list_batch.append(item)
                batch.list_batch.append(list_batch)

        if num_nodes is None:
            batch.batch = None

        # Concatenate collected attributes; list-valued attributes are
        # concatenated scale by scale.
        for key in batch.keys:

            item = batch[key][0]

            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])

            elif isinstance(item, list):
                item = batch[key]
                res = []
                for j in range(len(item[0])):
                    col = [f[j] for f in batch[key]]
                    res.append(
                        torch.cat(col, dim=data_list[0].__cat_dim__(key, col)))
                batch[key] = res
                # print('item', item)

            else:
                raise ValueError('{} is an Unsupported attribute type'.format(
                    type(item)))

        # Copy custom data functions to batch (does not work yet):
        # if data_list.__class__ != Data:
        #     org_funcs = set(Data.__dict__.keys())
        #     funcs = set(data_list[0].__class__.__dict__.keys())
        #     batch.__custom_funcs__ = funcs.difference(org_funcs)
        #     for func in funcs.difference(org_funcs):
        #         setattr(batch, func, getattr(data_list[0], func))

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #12
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch()
        # Mirror the plain (non-dunder) attributes of the element class so
        # the batch exposes the same attribute surface.
        for key in data_list[0].__dict__.keys():
            if key[:2] != '__' and key[-2:] != '__':
                batch[key] = None

        batch.__data_class__ = data_list[0].__class__
        for key in keys + ['batch']:
            batch[key] = []

        device = None
        # Bookkeeping: per-key concat boundaries, cumulative index offsets,
        # concat dimensions, and original node counts (for unbatching).
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    # Skip the add when the offset is exactly int 0.
                    if not isinstance(cum, int) or cum != 0:
                        item = item + cum
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        if not isinstance(cum, int) or cum != 0:
                            value = value + cum
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Treat 0-dimensional tensors as 1-dimensional.
                if isinstance(item, Tensor) and item.dim() == 0:
                    item = item.unsqueeze(0)

                batch[key].append(item)

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                cat_dims[key] = cat_dim
                if isinstance(item, Tensor):
                    size = item.size(cat_dim)
                    device = item.device
                elif isinstance(item, SparseTensor):
                    # `cat_dim` may index several dims, so `size` can be a
                    # tensor here (handled in the follow_batch branch).
                    size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                    device = item.device()

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        # One assignment vector per concat dimension.
                        for j, size in enumerate(size.tolist()):
                            tmp = f'{key}_{j}_batch'
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size, ),
                                           i,
                                           dtype=torch.long,
                                           device=device))
                    else:
                        tmp = f'{key}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ),
                                       i,
                                       dtype=torch.long,
                                       device=device))

            if hasattr(data, '__num_nodes__'):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ),
                                  i,
                                  dtype=torch.long,
                                  device=device)
                batch.batch.append(item)

        # Fix initial slice values:
        # (subtracting an entry from itself yields a zero of the matching
        # type/shape — int 0 or a zero tensor)
        for key in keys:
            slices[key][0] = slices[key][1] - slices[key][1]

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        # Concatenate the per-key lists into single objects.
        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, SparseTensor):
                batch[key] = cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #13
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""

        keys = [set(data.keys) for data in data_list]
        keys = set.union(*keys)
        # `depth_count` is aggregated separately below, so exclude it from
        # the generic per-key handling.
        keys.remove('depth_count')
        keys = list(keys)
        depth = max(data.depth_count.shape[0] for data in data_list)
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        # Per-depth node offsets.
        # NOTE(review): `th` and `torch` are both used here — presumably
        # both alias PyTorch; confirm against the file's imports.
        cumsum = {i: 0 for i in range(depth)}
        depth_count = th.zeros((depth, ), dtype=th.long)
        batch.batch = []
        for i, data in enumerate(data_list):
            edges = data['edge_index']
            # Shift edge ids depth by depth: entries whose `depth_mask` is
            # `d` are offset by the running count of depth `d - 1` nodes.
            # NOTE(review): this mutates `data['edge_index']` in place.
            for d in range(1, depth):
                mask = data.depth_mask == d
                edges[mask] += cumsum[d - 1]
                cumsum[d - 1] += data.depth_count[d - 1].item()
            batch['edge_index'].append(edges)
            depth_count += data['depth_count']

            # All remaining keys are collected without offsetting.
            for key in data.keys:
                if key == 'edge_index' or key == 'depth_count':
                    continue
                item = data[key]
                batch[key].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        if num_nodes is None:
            batch.batch = None

        # Concatenate the per-key lists into single tensors.
        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])

        batch.depth_count = depth_count
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #14
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`."""

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        for key in keys + ['batch']:
            batch[key] = []

        # Bookkeeping: per-key concat boundaries, cumulative index offsets,
        # concat dimensions, and original node counts.
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        mm = 0  # NOTE(review): only used by the commented-out debug prints
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                #  DAGNN
                # DAGNN-specific keys: only the id column (dim 1) is
                # shifted so the layer dimension stays aligned when layers
                # of all data objects are merged.
                if key in ["bi_layer_index", "bi_layer_parent_index"]:
                    # if key == "bi_layer_index":
                    #     # print("now")
                    #     mm = torch.max(item).item()
                    #     print(mm, data.x.shape[0])
                    item[:, 1] = item[:, 1] + cum
                    # if key == "bi_layer_index":
                    #     # print("now")
                    #     mm = torch.max(item).item()
                    #     print(mm) #, data.x.shape[0])
                elif key in ["layer_index", "layer_parent_index"]:
                    # keep layer dimension 0 and only increase ids in dimension 1 (we merge layers of all data objects)
                    item[1] = item[1] + cum
                elif isinstance(item, Tensor) and item.dtype != torch.bool:
                    item = item + cum if cum != 0 else item
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        value = value + cum if cum != 0 else value
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Treat 0-dimensional tensors as 1-dimensional.
                if isinstance(item, Tensor) and item.dim() == 0:
                    item = item.unsqueeze(0)

                batch[key].append(item)

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                cat_dims[key] = cat_dim
                if isinstance(item, Tensor):
                    size = item.size(cat_dim)
                elif isinstance(item, SparseTensor):
                    # `cat_dim` may index several dims, so `size` can be a
                    # tensor (handled in the follow_batch branch below).
                    size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        # One assignment vector per concat dimension.
                        for j, size in enumerate(size.tolist()):
                            tmp = f'{key}_{j}_batch'
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size, ), i, dtype=torch.long))
                    else:
                        tmp = f'{key}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ), i, dtype=torch.long))

            if hasattr(data, '__num_nodes__'):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            # print(mm, num_nodes, sum(num_nodes_list))
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # Fix initial slice values:
        # (subtracting an entry from itself yields a zero of matching type)
        for key in keys:
            slices[key][0] = slices[key][1] - slices[key][1]

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        # Concatenate the per-key lists into single objects.
        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, SparseTensor):
                batch[key] = cat(items, ref_data.__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #15
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Args:
            data_list (list): Data objects to collate into a single batch.
            follow_batch (list, optional): Keys for which an additional
                ``<key>_batch`` assignment vector is created.
                (default: :obj:`None`, treated as the empty list)
        """
        # NOTE(fix): the original signature used a mutable default argument
        # (`follow_batch=[]`), which is shared across calls.
        follow_batch = [] if follow_batch is None else follow_batch

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}
        for key in keys:
            batch[key] = []
        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        cumsum = {key: 0 for key in keys}
        batch.batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                # Shift index-like tensors by the running offset so they keep
                # pointing at the right rows after concatenation; boolean
                # masks must not be shifted.
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, data[key]))
                else:
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                cumsum[key] = cumsum[key] + data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size,), i, dtype=torch.long)
                    batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes,), i, dtype=torch.long)
                batch.batch.append(item)

        # NOTE(fix): the original checked only the *last* graph's `num_nodes`
        # (and raised NameError for an empty `data_list`); deciding on the
        # collected assignment vectors themselves is robust in both cases.
        if len(batch.batch) == 0:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                # `__cat_dim__` tells us along which dimension attributes of
                # this key are concatenated (e.g. the last for `edge_index`).
                cat_dim = data_list[0].__cat_dim__(key, item)
                batch[key] = torch.cat(batch[key], dim=cat_dim)
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #16
0
    def from_data_list(cls, data_list, follow_batch=None, exclude_keys=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Args:
            data_list (list): Data objects to collate into a single batch.
            follow_batch (list, optional): Keys for which an additional
                ``<key>_batch`` assignment vector is created.
                (default: :obj:`None`, treated as the empty list)
            exclude_keys (list, optional): Keys to leave out of the batch.
                (default: :obj:`None`, treated as the empty list)
        """
        # NOTE(fix): the original signature used mutable default arguments
        # (`follow_batch=[]`, `exclude_keys=[]`), which are shared across
        # calls; `None` sentinels avoid that pitfall.
        follow_batch = [] if follow_batch is None else follow_batch
        exclude_keys = [] if exclude_keys is None else exclude_keys

        keys = set(data_list[0].keys)
        keys -= {'num_nodes', 'num_edges'}
        keys -= set(exclude_keys)
        keys = sorted(list(keys))
        assert 'batch' not in keys and 'ptr' not in keys

        batch = cls()
        batch.__num_graphs__ = len(data_list)
        batch.__data_class__ = data_list[0].__class__
        for key in keys + ['batch']:
            batch[key] = []
        batch.ptr = [0]

        device = None
        slices = {key: [0] for key in keys}  # per-key slice boundaries
        cumsum = {key: [0] for key in keys}  # per-key running index offsets
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value so index-like attributes
                # keep pointing at the right rows after concatenation.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        item = item + cum
                elif isinstance(item, SparseTensor):
                    value = item.storage.value()
                    if value is not None and value.dtype != torch.bool:
                        if not isinstance(cum, int) or cum != 0:
                            value = value + cum
                        item = item.set_value(value, layout='coo')
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                # 0-dimensional tensors have no dimension along which to
                # concatenate, so we set `cat_dim` to `None`.
                if isinstance(item, Tensor) and item.dim() == 0:
                    cat_dim = None
                cat_dims[key] = cat_dim

                # Add a batch dimension to items whose `cat_dim` is `None`:
                if isinstance(item, Tensor) and cat_dim is None:
                    cat_dim = 0  # Concatenate along this new batch dimension.
                    item = item.unsqueeze(0)
                    device = item.device
                elif isinstance(item, Tensor):
                    size = item.size(cat_dim)
                    device = item.device
                elif isinstance(item, SparseTensor):
                    size = torch.tensor(item.sizes())[torch.tensor(cat_dim)]
                    device = item.device()

                batch[key].append(item)  # Append item to the attribute list.

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        # NOTE(fix): the original rebound the loop variable
                        # `size` while iterating `size.tolist()`; use a
                        # distinct name for the per-dimension size.
                        for j, sub_size in enumerate(size.tolist()):
                            tmp = f'{key}_{j}_batch'
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((sub_size, ),
                                           i,
                                           dtype=torch.long,
                                           device=device))
                    else:
                        tmp = f'{key}_batch'
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size, ),
                                       i,
                                       dtype=torch.long,
                                       device=device))

            if 'num_nodes' in data:
                num_nodes_list.append(data.num_nodes)
            else:
                num_nodes_list.append(None)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ),
                                  i,
                                  dtype=torch.long,
                                  device=device)
                batch.batch.append(item)
                batch.ptr.append(batch.ptr[-1] + num_nodes)

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        # Concatenate the per-graph lists into single tensors/values.
        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            item = items[0]
            cat_dim = ref_data.__cat_dim__(key, item)
            cat_dim = 0 if cat_dim is None else cat_dim
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, cat_dim)
            elif isinstance(item, SparseTensor):
                batch[key] = cat(items, cat_dim)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch
Example #17
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Args:
            data_list (list): Data objects to collate; each is expected to
                describe a bipartite factor graph (exposing ``numFactors``
                and ``numVars`` attributes).
            follow_batch (list, optional): Keys for which an additional
                ``<key>_batch`` assignment vector is created.
                (default: :obj:`None`, treated as the empty list)
        """
        # NOTE(fix): replaced the mutable default `follow_batch=[]` (shared
        # across calls) with a `None` sentinel, and removed leftover debug
        # `print` statements and commented-out code.
        follow_batch = [] if follow_batch is None else follow_batch

        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = Batch_custom()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        cumsum = {key: 0 for key in keys}
        batch.batch = []
        # We have a bipartite graph, so keep track of batches for each set of
        # nodes separately.
        batch.batch_factors = []  # for factor beliefs
        batch.batch_vars = []  # for variable beliefs
        junk_bin_val = 0  # accumulated increment; final "junk bin" index
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    if key == "facStates_to_varIdx":
                        # -1 entries are a sentinel (the "junk bin") and must
                        # not be shifted.  Clone first: without this we would
                        # edit the data for the next epoch, causing errors.
                        item = item.clone()
                        valid = torch.where(item != -1)
                        item[valid] = item[valid] + cumsum[key]
                    else:
                        item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, data[key]))
                else:
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                if key == "facStates_to_varIdx":
                    facStates_to_varIdx_inc = data.__inc__(key, item)
                    junk_bin_val += facStates_to_varIdx_inc
                    cumsum[key] += facStates_to_varIdx_inc
                else:
                    cumsum[key] += data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size, ), i, dtype=torch.long)
                    batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

            batch.batch_factors.append(
                torch.full((data.numFactors, ), i, dtype=torch.long))
            batch.batch_vars.append(
                torch.full((data.numVars, ), i, dtype=torch.long))

        # NOTE(fix): the original checked only the *last* graph's `num_nodes`
        # (and raised NameError for an empty `data_list`); deciding on the
        # collected assignment vectors themselves is robust in both cases.
        if len(batch.batch) == 0:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
                if key == "facStates_to_varIdx":
                    # Redirect every remaining sentinel to the shared junk
                    # bin at the end of the concatenated index space.
                    batch[key][torch.where(batch[key] == -1)] = junk_bin_val
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])
            else:
                raise ValueError('Unsupported attribute type')

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #18
0
def test_debug():
    """Verify the global debug flag toggles correctly via the `set_debug`
    switch, `set_debug` used as a context manager (including nesting), and
    the `debug()` helper context manager."""
    enabled = is_debug_enabled

    # Plain on/off switching.
    assert enabled() is False
    set_debug(True)
    assert enabled() is True
    set_debug(False)
    assert enabled() is False

    # `set_debug(True)` as a context manager restores the previous state.
    assert enabled() is False
    with set_debug(True):
        assert enabled() is True
    assert enabled() is False

    # A `set_debug(False)` context nested inside an enabled state disables
    # debugging only for the duration of the `with` block.
    assert enabled() is False
    set_debug(True)
    assert enabled() is True
    with set_debug(False):
        assert enabled() is False
    assert enabled() is True
    set_debug(False)
    assert enabled() is False

    # The `debug()` helper enables debugging for the block only.
    assert enabled() is False
    with debug():
        assert enabled() is True
    assert enabled() is False