Example #1
0
def from_data_list_token(data_list, follow_batch=()):
    """Collate a list of ``Data`` objects into a single ``Batch``.

    Mostly a copy of PyTorch Geometric's ``Batch.from_data_list`` with one
    difference: negative index entries are treated as sentinel values and
    are NOT shifted by the cumulative offset — only entries ``>= 0`` are
    incremented.

    :param data_list: list of ``torch_geometric.data.Data``-like objects.
    :param follow_batch: keys for which an extra ``<key>_batch``
        assignment vector is created.
    :return: a contiguous ``Batch`` with concatenated attributes.
    :raises ValueError: if an attribute is neither a tensor nor an
        ``int``/``float``.
    """
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    # __slices__[key][j] is the cumulative size of `key` after j graphs.
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Clone before shifting: the previous code wrote through a
                # boolean mask directly into data[key], silently corrupting
                # the caller's data_list on every collate.
                item = item.clone()
                mask = item >= 0
                item[mask] += cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # num_nodes holds the value from the last graph; PyG convention is that
    # either all graphs report num_nodes or none do.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
Example #2
0
def collate_fn_withpad(data_list):
    """Collate ``Data`` objects into one ``Batch``, zero-padding 3-D tensors.

    Modified from PyTorch-Geometric's default collate: attributes whose
    tensors are 3-dimensional are right-padded with zeros along dim 1 up to
    the longest one before concatenation along dim 0, and the original
    lengths are stored under ``tlens``.

    :param data_list: list of ``Data``-like objects.
    :return: a contiguous ``Batch``.
    """
    attr_names = list(set.union(*(set(d.keys) for d in data_list)))
    assert 'batch' not in attr_names

    batch = Batch()
    for name in attr_names:
        batch[name] = []
    batch.batch = []

    node_offset = 0
    for graph_idx, data in enumerate(data_list):
        n = data.num_nodes
        batch.batch.append(torch.full((n, ), graph_idx, dtype=torch.long))
        for name in data.keys:
            value = data[name]
            # Only index-like attributes (per __cumsum__) get shifted.
            if data.__cumsum__(name, value):
                value = value + node_offset
            batch[name].append(value)
        node_offset += n

    for name in attr_names:
        first = batch[name][0]
        if torch.is_tensor(first):
            if len(first.shape) == 3:
                # Pad every tensor along dim 1 to the common maximum length.
                tlens = [t.shape[1] for t in batch[name]]
                maxtlens = np.max(tlens)
                padded = [
                    torch.cat(
                        [t,
                         t.new_zeros(t.shape[0], maxtlens - t.shape[1],
                                     t.shape[2])],
                        dim=1)
                    for t in batch[name]
                ]
                batch[name] = torch.cat(padded, dim=0)
                # Record original lengths once (first 3-D attribute wins).
                if 'tlens' not in batch.keys:
                    batch['tlens'] = first.new_tensor(tlens, dtype=torch.long)
            else:
                batch[name] = torch.cat(
                    batch[name], dim=data_list[0].__cat_dim__(name, first))
        elif isinstance(first, (int, float)):
            batch[name] = torch.tensor(batch[name])
        else:
            raise ValueError('Unsupported attribute type.')

    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()
Example #3
0
def collate(data_list):
    """Collate ``Data`` objects carrying optional level-2/level-3 structure.

    Besides the usual node-index offset for ``edge_index``, the auxiliary
    edge indices and assignment matrices (``*_2``, ``*_3``,
    ``assignment_2to3``) are shifted by their own cumulative counts, and a
    per-level ``batch`` assignment vector is built for each present level.

    :param data_list: list of ``Data``-like objects (keys taken from the
        first element).
    :return: a contiguous ``Batch``.
    """
    keys = data_list[0].keys
    assert 'batch' not in keys

    batch = Batch()
    for key in keys:
        batch[key] = []
    batch.batch = []
    if 'edge_index_2' in keys:
        batch.batch_2 = []
    if 'edge_index_3' in keys:
        batch.batch_3 = []

    # edge_index and the level-2/3 attributes are handled explicitly below;
    # everything else is appended verbatim.
    keys.remove('edge_index')
    special = ('edge_index_2', 'assignment_2', 'edge_index_3', 'assignment_3',
               'assignment_2to3')
    plain_keys = [k for k in keys if k not in special]

    offset_1 = offset_2 = offset_3 = 0
    count_1 = count_2 = count_3 = 0

    for graph_idx, data in enumerate(data_list):
        for key in plain_keys:
            batch[key].append(data[key])

        count_1 = data.num_nodes
        batch.edge_index.append(data.edge_index + offset_1)
        batch.batch.append(torch.full((count_1, ), graph_idx,
                                      dtype=torch.long))

        if 'edge_index_2' in data:
            count_2 = data.assignment_2[1].max().item() + 1
            batch.edge_index_2.append(data.edge_index_2 + offset_2)
            # Row 0 indexes nodes, row 1 indexes level-2 entities.
            batch.assignment_2.append(
                data.assignment_2 + torch.tensor([[offset_1], [offset_2]]))
            batch.batch_2.append(
                torch.full((count_2, ), graph_idx, dtype=torch.long))

        if 'edge_index_3' in data:
            count_3 = data.assignment_3[1].max().item() + 1
            batch.edge_index_3.append(data.edge_index_3 + offset_3)
            batch.assignment_3.append(
                data.assignment_3 + torch.tensor([[offset_1], [offset_3]]))
            batch.batch_3.append(
                torch.full((count_3, ), graph_idx, dtype=torch.long))

        if 'assignment_2to3' in data:
            assert 'edge_index_2' in data and 'edge_index_3' in data
            batch.assignment_2to3.append(
                data.assignment_2to3 + torch.tensor([[offset_2], [offset_3]]))

        offset_1 += count_1
        offset_2 += count_2
        offset_3 += count_3

    tensor_keys = [k for k in batch.keys
                   if k not in ('batch', 'batch_2', 'batch_3')]
    for key in tensor_keys:
        batch[key] = torch.cat(batch[key], dim=data_list[0].cat_dim(key))

    batch.batch = torch.cat(batch.batch, dim=-1)

    if 'batch_2' in batch:
        batch.batch_2 = torch.cat(batch.batch_2, dim=-1)

    if 'batch_3' in batch:
        batch.batch_3 = torch.cat(batch.batch_3, dim=-1)

    return batch.contiguous()
Example #4
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Unlike the stock implementation, ``edge_index`` entries are offset
        per depth level using each graph's ``depth_mask`` / ``depth_count``
        attributes, and the element-wise sum of all ``depth_count`` tensors
        is stored on the resulting batch.
        """

        # Union of attribute keys; depth_count is summed separately rather
        # than concatenated, so it is excluded from the generic handling.
        keys = [set(data.keys) for data in data_list]
        keys = set.union(*keys)
        keys.remove('depth_count')
        keys = list(keys)
        # Deepest level across all graphs; assumes depth_count is a 1-D
        # per-level count tensor -- TODO confirm against the Data class.
        depth = max(data.depth_count.shape[0] for data in data_list)
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        # NOTE(review): these '<key>_batch' lists are created but never
        # appended to below; a non-empty follow_batch would presumably hit
        # batch[key][0] on an empty list -- verify with callers.
        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        # One running offset per depth level (keyed by level index).
        cumsum = {i: 0 for i in range(depth)}
        # `th` is presumably an alias for torch (import torch as th).
        depth_count = th.zeros((depth, ), dtype=th.long)
        batch.batch = []
        for i, data in enumerate(data_list):
            edges = data['edge_index']
            for d in range(1, depth):
                # Edges whose depth_mask is d point at level d-1 entities,
                # so they are shifted by that level's running offset.
                mask = data.depth_mask == d
                # NOTE(review): this writes into data['edge_index'] in
                # place, mutating the input data_list -- confirm intended.
                edges[mask] += cumsum[d - 1]
                cumsum[d - 1] += data.depth_count[d - 1].item()
            batch['edge_index'].append(edges)
            depth_count += data['depth_count']

            # All remaining attributes are collected verbatim.
            for key in data.keys:
                if key == 'edge_index' or key == 'depth_count':
                    continue
                item = data[key]
                batch[key].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # num_nodes is the last graph's value; convention is that either
        # every graph reports num_nodes or none does.
        if num_nodes is None:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                batch[key] = torch.tensor(batch[key])

        batch.depth_count = depth_count
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #5
0
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`."""

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key] + cumsum.get(key, 0)
            if key in cumsum:
                if key == 'edge_index':
                    cumsum[key] += data['x_numeric'].shape[0]
                
                if key == 'subgraphs_nodes':
                    cumsum[key] += data['x_numeric'].shape[0]
                if key == 'subgraphs_edges':
                    cumsum[key] += data['edge_attr_numeric'].shape[0]
                if key == 'subgraphs_connectivity':
                    cumsum[key] += data['subgraphs_nodes'].shape[0]
                    
                if key == 'subgraphs_nodes_path_batch':
                    cumsum[key] += data['subgraphs_nodes_path_batch'].max() + 1
                if key == 'subgraphs_edges_path_batch':
                    cumsum[key] += data['subgraphs_edges_path_batch'].max() + 1
                if key == 'subgraphs_global_path_batch':
                    cumsum[key] += 1
                    
                if key == 'cycles_id':
                    if data['cycles_id'].shape[0] > 0:
                        cumsum[key] += data['cycles_id'].max() + 1
                        
                if key == 'cycles_edge_index':
                    cumsum[key] += data['edge_attr_numeric'].shape[0]
                    
                if key == 'edges_connectivity_ids':
                    cumsum[key] += data['edge_attr_numeric'].shape[0]
                    
            else:
                if key == 'edge_index':
                    cumsum[key] = data['x_numeric'].shape[0]
                
                if key == 'subgraphs_nodes':
                    cumsum[key] = data['x_numeric'].shape[0]
                if key == 'subgraphs_edges':
                    cumsum[key] = data['edge_attr_numeric'].shape[0]
                if key == 'subgraphs_connectivity':
                    cumsum[key] = data['subgraphs_nodes'].shape[0]
                
                
                if key == 'subgraphs_nodes_path_batch':
                    cumsum[key] = data['subgraphs_nodes_path_batch'].max() + 1
                if key == 'subgraphs_edges_path_batch':
                    cumsum[key] = data['subgraphs_edges_path_batch'].max() + 1
                if key == 'subgraphs_global_path_batch':
                    cumsum[key] = 1
                    
                    
                if key == 'cycles_id':
                    if data['cycles_id'].shape[0] > 0:
                        cumsum[key] = data['cycles_id'].max() + 1
                        
                if key == 'cycles_edge_index':
                    cumsum[key] = data['edge_attr_numeric'].shape[0]
                    
                if key == 'edges_connectivity_ids':
                    cumsum[key] = data['edge_attr_numeric'].shape[0]


            batch[key].append(item)
            
        for key in follow_batch:
            size = data[key].size(data.__cat_dim__(key, data[key]))
            item = torch.full((size, ), i, dtype=torch.long)
            batch['{}_batch'.format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)

    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if key not in ['subgraphs_connectivity', 'edges_connectivity_ids']:
                batch[key] = torch.cat(
                    batch[key], dim=data_list[0].__cat_dim__(key, item))
            else:
                batch[key] = torch.cat(
                    batch[key], dim=-1)
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    # Copy custom data functions to batch (does not work yet):
    # if data_list.__class__ != Data:
    #     org_funcs = set(Data.__dict__.keys())
    #     funcs = set(data_list[0].__class__.__dict__.keys())
    #     batch.__custom_funcs__ = funcs.difference(org_funcs)
    #     for func in funcs.difference(org_funcs):
    #         setattr(batch, func, getattr(data_list[0], func))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
Example #6
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Build a :class:`Batch` from a python list of
        :class:`torch_geometric.data.Data` objects.

        The per-node assignment vector :obj:`batch` is created on the fly,
        and an extra ``<key>_batch`` assignment vector is built for every
        key listed in :obj:`follow_batch`.
        """
        keys = list(set.union(*(set(d.keys) for d in data_list)))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        # __slices__[key][j] is the cumulative size of `key` after j graphs.
        batch.__slices__ = {k: [0] for k in keys}
        for k in keys:
            batch[k] = []
        for k in follow_batch:
            batch['{}_batch'.format(k)] = []

        cumsum = dict.fromkeys(keys, 0)
        batch.batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                # Bool tensors are masks, not indices: never shifted.
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    item = item + cumsum[key]
                size = (item.size(data.__cat_dim__(key, data[key]))
                        if torch.is_tensor(item) else 1)
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                cumsum[key] = cumsum[key] + data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    batch['{}_batch'.format(key)].append(
                        torch.full((size,), i, dtype=torch.long))

            num_nodes = data.num_nodes
            if num_nodes is not None:
                batch.batch.append(
                    torch.full((num_nodes,), i, dtype=torch.long))

        # num_nodes is the last graph's value; either every graph reports
        # it or none does.
        if num_nodes is None:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            logger.debug(f"key = {key}")
            if torch.is_tensor(item):
                logger.debug(f"batch[{key}]")
                logger.debug(f"item.shape = {item.shape}")
                elem = data_list[0]
                dim_ = elem.__cat_dim__(key, item)
                batch[key] = torch.cat(batch[key], dim=dim_)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()