Exemplo n.º 1
0
def from_data_list_token(data_list, follow_batch=[]):
    """Collate a list of data objects into one ``Batch``, keeping negative entries unshifted.

    This follows the ``from_data_list`` logic of the PyTorch Geometric batch
    object, with one difference: for non-boolean tensors only the entries
    that are ``>= 0`` receive the running per-key offset, so negative values
    keep their original value.

    The tensors stored on the input objects are left untouched: the offset is
    applied to a copy, so the same ``data_list`` can be collated repeatedly
    without its indices drifting.

    Args:
        data_list: Non-empty list of objects exposing ``keys``, ``num_nodes``,
            ``__cat_dim__`` and ``__inc__``.
        follow_batch: Keys for which an additional ``"<key>_batch"``
            assignment vector is built. The list is only read, never mutated.

    Returns:
        The value of ``batch.contiguous()`` for the assembled batch.

    Raises:
        ValueError: If a collected attribute is neither a tensor, an ``int``
            nor a ``float``.
    """

    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                mask = item >= 0
                # Shift a copy rather than the caller's tensor; writing into
                # data[key] directly would corrupt it for any later use.
                item = item.clone()
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, item))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # Only the last element's num_nodes decides whether a batch vector exists.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
Exemplo n.º 2
0
def collate_fn_withpad(data_list):
    '''
    Collate graph samples into one ``Batch``, zero-padding 3-D tensors.

    Adapted from PyTorch-Geometric's collate implementation. Attributes whose
    first collected value is a 3-D tensor are padded with zeros along
    dimension 1 up to the longest sample and concatenated along dimension 0;
    the unpadded lengths are stored under ``'tlens'`` unless that key already
    exists. Other tensors are concatenated along ``__cat_dim__``.

    :param data_list: list of samples exposing ``keys``, ``num_nodes``,
        ``__cumsum__`` and ``__cat_dim__``
    :return: the value of ``batch.contiguous()``
    :raises ValueError: for attributes that are not tensors, ints or floats
    '''
    all_keys = list(set.union(*[set(sample.keys) for sample in data_list]))
    assert 'batch' not in all_keys

    batch = Batch()
    for name in all_keys:
        batch[name] = []
    batch.batch = []

    node_offset = 0
    for graph_id, sample in enumerate(data_list):
        node_count = sample.num_nodes
        batch.batch.append(
            torch.full((node_count, ), graph_id, dtype=torch.long))
        for name in sample.keys:
            value = sample[name]
            # Only attributes flagged by __cumsum__ are shifted by the number
            # of nodes seen so far.
            if sample.__cumsum__(name, value):
                value = value + node_offset
            batch[name].append(value)
        node_offset += node_count

    for name in all_keys:
        first = batch[name][0]
        if torch.is_tensor(first):
            if first.dim() != 3:
                batch[name] = torch.cat(
                    batch[name], dim=data_list[0].__cat_dim__(name, first))
                continue

            tlens = [t.shape[1] for t in batch[name]]
            longest = np.max(tlens)
            padded = [
                torch.cat(
                    [t, t.new_zeros(t.shape[0], longest - t.shape[1],
                                    t.shape[2])],
                    dim=1) for t in batch[name]
            ]
            batch[name] = torch.cat(padded, dim=0)
            if 'tlens' not in batch.keys:
                batch['tlens'] = first.new_tensor(tlens, dtype=torch.long)
        elif isinstance(first, (int, float)):
            batch[name] = torch.tensor(batch[name])
        else:
            raise ValueError('Unsupported attribute type.')

    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()
Exemplo n.º 3
0
 def collate(data_list):
     """Collate samples with ``default_collate`` and build the batch vector.

     Every key of the first sample is collated across all samples with
     ``default_collate``. ``batch.batch`` is the concatenation, over samples
     whose ``num_nodes`` is not ``None``, of a vector filled with the
     sample's position in ``data_list``.
     """
     batch = Batch()
     batch.batch = []
     for key in data_list[0].keys:
         batch[key] = default_collate([sample[key] for sample in data_list])

     for graph_id, sample in enumerate(data_list):
         node_count = sample.num_nodes
         if node_count is None:
             continue
         batch.batch.append(
             torch.full((node_count, ), graph_id, dtype=torch.long))

     batch.batch = torch.cat(batch.batch, dim=0)
     return batch
Exemplo n.º 4
0
    def forward(self, data):
        """Run the convolution once per neighbourhood scale and concatenate.

        When ``self._precompute_multi_scale`` is set, the sample indices and
        per-scale edge indices are read from attributes of ``data``
        (``idx_<index>`` and ``edge_index_<index>_<scale>``, ``None`` when
        absent); otherwise they come from ``self.sampler`` and
        ``self.neighbour_finder``.

        Returns a new ``Batch`` with the concatenated features, the sampled
        positions and batch vector, after handing it to ``copy_from_to``.
        """
        out = Batch()
        x, pos, batch = data.x, data.pos, data.batch
        precomputed = self._precompute_multi_scale

        if precomputed:
            idx = getattr(data, f"idx_{self._index}", None)
        else:
            idx = self.sampler(pos, batch)
            out.idx = idx

        per_scale = []
        for scale_idx in range(self.neighbour_finder.num_scales):
            if precomputed:
                edge_index = getattr(
                    data, f"edge_index_{self._index}_{scale_idx}", None)
            else:
                row, col = self.neighbour_finder(
                    pos,
                    pos[idx],
                    batch_x=batch,
                    batch_y=batch[idx],
                    scale_idx=scale_idx,
                )
                # Edge index is stored as (col, row).
                edge_index = torch.stack([col, row], dim=0)

            per_scale.append(
                self.conv(x, (pos, pos[idx]), edge_index, batch))

        # Features of all scales are joined along the last dimension.
        out.x = torch.cat(per_scale, -1)
        out.pos = pos[idx]
        out.batch = batch[idx]
        copy_from_to(data, out)
        return out
    def forward(self, data):
        """Apply the convolution on gathered neighbourhoods.

        One extra "shadow" row, filled with ``self.shadow_features_fill`` /
        ``self.shadow_points_fill_``, is appended to the features and to the
        positions before gathering.
        NOTE(review): presumably the neighbour finder uses the index of that
        extra row for missing neighbours -- confirm against its
        implementation.

        Returns a new ``Batch`` holding the convolution output and the
        positions / batch entries selected by the sampler, after handing it
        to ``copy_from_to``.
        """
        features, points, batch = data.x, data.pos, data.batch
        sampled = self.sampler(pos=points, x=features, batch=batch)
        neighbours, _ = self.neighbour_finder(
            points, points, batch_x=batch, batch_y=batch)

        pad_feature = torch.full((1, ) + features.shape[1:],
                                 self.shadow_features_fill)
        pad_feature = pad_feature.to(features.device)
        pad_point = torch.full((1, ) + points.shape[1:],
                               self.shadow_points_fill_)
        pad_point = pad_point.to(features.device)

        features = torch.cat([features, pad_feature], dim=0)
        points = torch.cat([points, pad_point], dim=0)

        neighbour_features = features[neighbours]
        # Express neighbour coordinates relative to each original (non-shadow)
        # point.
        centred_neighbours = points[neighbours] - points[:-1].unsqueeze(1)

        result = Batch()
        result.x = self.conv(features, points, neighbour_features,
                             centred_neighbours, neighbours, sampled)
        result.pos = points[sampled]
        result.batch = batch[sampled]
        copy_from_to(data, result)
        return result
 def forward(self, data, **kwargs):
     """Pool features over each batch entry.

     Features and positions are concatenated along dimension 1, passed
     through ``self.nn`` and reduced with ``self.pool`` using the batch
     vector. The returned ``Batch`` has one all-zero 3-column position and
     one distinct batch index per pooled row, and is handed to
     ``copy_from_to`` before being returned.
     """
     x, pos, batch = data.x, data.pos, data.batch
     pooled = self.pool(self.nn(torch.cat([x, pos], dim=1)), batch)
     num_rows = pooled.size(0)

     out = Batch()
     out.x = pooled
     out.pos = pos.new_zeros((num_rows, 3))
     out.batch = torch.arange(num_rows, device=batch.device)
     copy_from_to(data, out)
     return out
Exemplo n.º 7
0
    def forward(self, data, **kwargs):
        """Sample points, find their neighbours and run one convolution.

        Returns a new ``Batch`` carrying the sample indices, the edge index,
        the convolution output and the sampled positions / batch vector,
        after handing it to ``copy_from_to``.
        """
        x, pos, batch = data.x, data.pos, data.batch
        idx = self.sampler(pos, batch)
        row, col = self.neighbour_finder(
            pos, pos[idx], batch_x=batch, batch_y=batch[idx])

        out = Batch()
        out.idx = idx
        # Edge index is stored as (col, row).
        out.edge_index = torch.stack([col, row], dim=0)

        # NOTE(review): positions are passed as (sampled, all) here; confirm
        # this order is what self.conv expects.
        out.x = self.conv(x, (pos[idx], pos), out.edge_index, batch)

        out.pos = pos[idx]
        out.batch = batch[idx]
        copy_from_to(data, out)
        return out
Exemplo n.º 8
0
    def forward(self, data, **kwargs):
        """Run the convolution once per neighbourhood scale and concatenate.

        Points are sampled once; for every scale of ``self.neighbour_finder``
        a neighbourhood is computed and convolved, and the per-scale outputs
        are concatenated along the last dimension.

        Returns a new ``Batch`` with the sample indices, concatenated
        features and the sampled positions / batch vector, after handing it
        to ``copy_from_to``.
        """
        x, pos, batch = data.x, data.pos, data.batch
        out = Batch()
        idx = self.sampler(pos, batch)
        out.idx = idx

        per_scale = []
        for scale_idx in range(self.neighbour_finder.num_scales):
            row, col = self.neighbour_finder(
                pos,
                pos[idx],
                batch_x=batch,
                batch_y=batch[idx],
                scale_idx=scale_idx,
            )
            # Edge index is stored as (col, row).
            edge_index = torch.stack([col, row], dim=0)
            per_scale.append(
                self.conv(x, (pos, pos[idx]), edge_index, batch))

        out.x = torch.cat(per_scale, -1)
        out.pos = pos[idx]
        out.batch = batch[idx]
        copy_from_to(data, out)
        return out
Exemplo n.º 9
0
 def forward(self, data, **kwargs):
     """Residual block: convolution branch plus a resized shortcut.

     The input features are kept as the shortcut, ``self.convs`` is applied
     to ``data``, its features are passed through
     ``self.features_upsample_nn`` and added to the shortcut after
     ``self.shortcut_feature_resize_nn``. When the convolution output
     carries a non-``None`` ``idx``, the shortcut is indexed with it first.

     Returns a new ``Batch`` with the summed features and the position /
     batch of the convolution output, after handing it to ``copy_from_to``.
     """
     shortcut = data.x  # (N, indim)
     # NOTE(review): the downsampled features are computed but never used;
     # self.convs below receives the original `data`. The call is kept so
     # behaviour is unchanged -- confirm whether its output was meant to be
     # fed into the convolutions.
     self.features_downsample_nn(shortcut)  # (N, outdim//4)

     # if this is an identity resnet block, idx will be None
     data = self.convs(data)  # (N', convdim)
     idx = data.idx
     x = self.features_upsample_nn(data.x)  # (N', outdim)

     if idx is not None:
         shortcut = shortcut[idx]  # (N', indim)
     shortcut = self.shortcut_feature_resize_nn(shortcut)  # (N', outdim)

     out = Batch()
     out.x = shortcut + x
     out.pos = data.pos
     out.batch = data.batch
     copy_from_to(data, out)
     return out
Exemplo n.º 10
0
    def forward(self, data):
        """Run one convolution on sampled points.

        Without precomputation the sample indices and edge index are computed
        here and stored on the result. With ``self._precompute_multi_scale``
        they are read from ``data`` (``index_<index>`` and
        ``edge_index_<index>``, ``None`` when absent) and are not stored on
        the result.

        Returns a new ``Batch`` with the convolution output and the sampled
        positions / batch vector, after handing it to ``copy_from_to``.
        """
        out = Batch()
        x, pos, batch = data.x, data.pos, data.batch

        if not self._precompute_multi_scale:
            idx = self.sampler(pos, batch)
            row, col = self.neighbour_finder(
                pos, pos[idx], batch_x=batch, batch_y=batch[idx])
            # Edge index is stored as (col, row).
            edge_index = torch.stack([col, row], dim=0)
            out.idx = idx
            out.edge_index = edge_index
        else:
            idx = getattr(data, f"index_{self._index}", None)
            edge_index = getattr(data, f"edge_index_{self._index}", None)

        out.x = self.conv(x, (pos, pos[idx]), edge_index, batch)

        out.pos = pos[idx]
        out.batch = batch[idx]
        copy_from_to(data, out)
        return out
Exemplo n.º 11
0
def _merge_graphs(graphs, labels):
    """Concatenate ``graphs`` into one ``Batch``, offsetting edge indices per graph."""
    node_feats, edge_indices, edge_feats, assignment = [], [], [], []
    node_offset = 0
    for slot, graph in enumerate(graphs):
        num_nodes = graph.x.size()[0]
        node_feats.append(graph.x)
        # Shift edge endpoints so they address this graph's rows in the
        # concatenated node feature matrix.
        edge_indices.append(graph.edge_index + node_offset)
        edge_feats.append(graph.edge_attr)
        assignment.append(
            torch.full([num_nodes], fill_value=slot, dtype=torch.long))
        node_offset += num_nodes

    merged = Batch()
    merged.edge_index = torch.cat(edge_indices, dim=1)
    merged.x = torch.cat(node_feats, dim=0)
    merged.edge_attr = torch.cat(edge_feats, dim=0)
    merged.y = torch.cat(labels, dim=0)
    merged.batch = torch.cat(assignment, dim=0)
    return merged


def data_loader(database, latent_graph_dict, latent_graph_list, batch_size=4, shuffle=True):
    """Build lists of mini-batches from a database table.

    For every row ``0 .. len(database) - 1`` (looked up with ``.loc``), the
    graphs named in columns ``"Molecule1"``, ``"Molecule2"`` and
    ``"Membrane"`` are fetched through ``latent_graph_dict`` /
    ``latent_graph_list``. The row's ``"molp1"``, ``"molp2"`` and ``"molpm"``
    values become the per-graph ``y`` of the batches, columns ``"2"`` ..
    ``"16"`` multiplied by 10 form the target, and ``"Temperature"`` is
    collected alongside.

    The graphs in ``latent_graph_list`` are only read. Labels are held per
    row, so rows that reuse the same molecule keep their own label. All five
    outputs of a batch (three graph batches, target, temperature) refer to
    the same rows, in the same order, whether or not the data is shuffled.

    Args:
        database: Table indexable by column name, whose columns support
            ``.loc[row]``.
        latent_graph_dict: Mapping from a molecule name to a position in
            ``latent_graph_list``.
        latent_graph_list: Sequence of graphs exposing ``x``, ``edge_index``
            and ``edge_attr``.
        batch_size: Maximum number of rows per batch; the last batch may be
            smaller.
        shuffle: When true, rows are assigned to batches in a random order
            drawn with ``torch.randperm``; otherwise in row order.

    Returns:
        ``[batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]``,
        five lists with one entry per batch.
    """
    num_rows = len(database)

    # Per-row graphs, labels, targets and temperatures.
    graphs_p1, graphs_p2, graphs_pm = [], [], []
    labels_p1, labels_p2, labels_pm = [], [], []
    target_list = []
    temperature_list = []

    for itm in range(num_rows):
        graphs_p1.append(latent_graph_list[latent_graph_dict[database["Molecule1"].loc[itm]]])
        graphs_p2.append(latent_graph_list[latent_graph_dict[database["Molecule2"].loc[itm]]])
        graphs_pm.append(latent_graph_list[latent_graph_dict[database["Membrane"].loc[itm]]])

        # Keep labels beside the graphs instead of assigning graph.y: the
        # same graph object is shared by every row that uses that molecule.
        labels_p1.append(torch.tensor([float(database["molp1"].loc[itm])]))
        labels_p2.append(torch.tensor([float(database["molp2"].loc[itm])]))
        labels_pm.append(torch.tensor([float(database["molpm"].loc[itm])]))

        target_list.append(
            torch.tensor([database[str(i)].loc[itm] * 10 for i in range(2, 17)], dtype=torch.float))
        temperature_list.append(database["Temperature"].loc[itm])

    # Row order used to fill the batches.
    if shuffle:
        idxs = torch.randperm(num_rows).tolist()
    else:
        idxs = list(range(num_rows))

    batch_p1 = []
    batch_p2 = []
    batch_pm = []
    batch_target = []
    batch_Temperature = []

    for start in range(0, num_rows, batch_size):
        rows = idxs[start:start + batch_size]

        batch_p1.append(
            _merge_graphs([graphs_p1[r] for r in rows], [labels_p1[r] for r in rows]))
        batch_p2.append(
            _merge_graphs([graphs_p2[r] for r in rows], [labels_p2[r] for r in rows]))
        batch_pm.append(
            _merge_graphs([graphs_pm[r] for r in rows], [labels_pm[r] for r in rows]))

        batch_target.append(torch.stack([target_list[r] for r in rows], dim=0))
        # Index by the (possibly shuffled) row so the temperature matches the
        # graphs and targets of the same sample.
        batch_Temperature.append(torch.Tensor([temperature_list[r] for r in rows]))

    return [batch_p1, batch_p2, batch_pm, batch_target, batch_Temperature]
Exemplo n.º 12
0
def collate(data_list):
    """Collate graphs that may carry second- and third-level structure.

    ``edge_index`` is shifted by the running node count. When present,
    ``edge_index_2`` / ``edge_index_3`` are shifted by the running counts of
    their own level, and the ``assignment_*`` tensors by one offset per row.
    The level-2 / level-3 sizes are taken as ``assignment_k[1].max() + 1``.
    All remaining keys of the first sample are collected unchanged and
    concatenated along ``cat_dim(key)``.

    Returns the value of ``batch.contiguous()``.
    """
    keys = data_list[0].keys
    assert 'batch' not in keys

    batch = Batch()
    for key in keys:
        batch[key] = []
    batch.batch = []
    if 'edge_index_2' in keys:
        batch.batch_2 = []
    if 'edge_index_3' in keys:
        batch.batch_3 = []

    # Keys that need an offset are handled explicitly below; everything else
    # is appended as-is.
    keys.remove('edge_index')
    shifted_keys = ('edge_index_2', 'assignment_2', 'edge_index_3',
                    'assignment_3', 'assignment_2to3')
    plain_keys = [key for key in keys if key not in shifted_keys]

    offset_1 = offset_2 = offset_3 = 0
    count_1 = count_2 = count_3 = 0

    for graph_id, data in enumerate(data_list):
        for key in plain_keys:
            batch[key].append(data[key])

        count_1 = data.num_nodes
        batch.edge_index.append(data.edge_index + offset_1)
        batch.batch.append(torch.full((count_1, ), graph_id, dtype=torch.long))

        if 'edge_index_2' in data:
            count_2 = data.assignment_2[1].max().item() + 1
            shift_1to2 = torch.tensor([[offset_1], [offset_2]])
            batch.edge_index_2.append(data.edge_index_2 + offset_2)
            batch.assignment_2.append(data.assignment_2 + shift_1to2)
            batch.batch_2.append(
                torch.full((count_2, ), graph_id, dtype=torch.long))

        if 'edge_index_3' in data:
            count_3 = data.assignment_3[1].max().item() + 1
            shift_1to3 = torch.tensor([[offset_1], [offset_3]])
            batch.edge_index_3.append(data.edge_index_3 + offset_3)
            batch.assignment_3.append(data.assignment_3 + shift_1to3)
            batch.batch_3.append(
                torch.full((count_3, ), graph_id, dtype=torch.long))

        if 'assignment_2to3' in data:
            assert 'edge_index_2' in data and 'edge_index_3' in data
            shift_2to3 = torch.tensor([[offset_2], [offset_3]])
            batch.assignment_2to3.append(data.assignment_2to3 + shift_2to3)

        offset_1 += count_1
        offset_2 += count_2
        offset_3 += count_3

    assignment_vectors = ['batch', 'batch_2', 'batch_3']
    for key in [k for k in batch.keys if k not in assignment_vectors]:
        batch[key] = torch.cat(batch[key], dim=data_list[0].cat_dim(key))

    batch.batch = torch.cat(batch.batch, dim=-1)
    if 'batch_2' in batch:
        batch.batch_2 = torch.cat(batch.batch_2, dim=-1)
    if 'batch_3' in batch:
        batch.batch_3 = torch.cat(batch.batch_3, dim=-1)

    return batch.contiguous()
Exemplo n.º 13
0
    def forward(self, plist):
        """Two-level message passing over three molecular graphs per sample.

        Level one encodes and message-passes each of the three graph batches
        with ``self.meta1`` to obtain one global vector per graph. Level two
        builds, for every sample, a 3-node graph (p1, p2, pm) with
        bidirectional edges p1<->pm and p2<->pm, whose edge features encode
        the temperature and the ratios ``p1.y / pm.y`` and ``p2.y / pm.y``.
        It is processed with ``self.meta3`` and ``self.mlp_last``.

        The number of samples is taken from the incoming ``Temperature``
        tensor rather than from ``self.batch_size``, so a final batch that is
        smaller than the configured size is handled as well.

        Args:
            plist: Sequence ``(p1, p2, pm, Temperature)`` of three graph
                batches exposing ``x``, ``edge_index``, ``edge_attr``, ``y``
                and ``batch``, plus a tensor with one temperature per sample.

        Returns:
            The output of ``self.mlp_last`` applied to the second-level
            global features.
        """
        p1, p2, pm, Temperature = plist
        x_p1, ei_p1, ea_p1, u_p1, btc_p1 = p1.x, p1.edge_index, p1.edge_attr, p1.y, p1.batch
        x_p2, ei_p2, ea_p2, u_p2, btc_p2 = p2.x, p2.edge_index, p2.edge_attr, p2.y, p2.batch
        x_pm, ei_pm, ea_pm, u_pm, btc_pm = pm.x, pm.edge_index, pm.edge_attr, pm.y, pm.batch

        # Actual number of samples in this call (may be below self.batch_size
        # for the last batch of an epoch).
        num_graphs = Temperature.size(0)

        # Embed the node and edge features
        enc_x_p1 = self.encoding_node_1(x_p1)
        enc_x_p2 = self.encoding_node_1(x_p2)
        enc_x_pm = self.encoding_node_1(x_pm)

        enc_ea_p1 = self.encoding_edge_1(ea_p1)
        enc_ea_p2 = self.encoding_edge_1(ea_p2)
        enc_ea_pm = self.encoding_edge_1(ea_pm)

        # Initial graph-level features for level one, one row per sample
        u1 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)
        u2 = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)
        um = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_ONE),
                        fill_value=0.1,
                        dtype=torch.float).to(device)

        # First-level rounds of message passing
        for _ in range(self.no_mp_one):
            enc_x_p1, enc_ea_p1, u1 = self.meta1(x=enc_x_p1,
                                                 edge_index=ei_p1,
                                                 edge_attr=enc_ea_p1,
                                                 u=u1,
                                                 batch=btc_p1)
            enc_x_p2, enc_ea_p2, u2 = self.meta1(x=enc_x_p2,
                                                 edge_index=ei_p2,
                                                 edge_attr=enc_ea_p2,
                                                 u=u2,
                                                 batch=btc_p2)
            enc_x_pm, enc_ea_pm, um = self.meta1(x=enc_x_pm,
                                                 edge_index=ei_pm,
                                                 edge_attr=enc_ea_pm,
                                                 u=um,
                                                 batch=btc_pm)

        # --- GRAPH GENERATION SECOND LEVEL

        # Encode the nodes second level
        u1 = self.encoding_node_2(u1)
        u2 = self.encoding_node_2(u2)
        um = self.encoding_node_2(um)

        nu_Batch = Batch()

        # Edges of one 3-node graph (no connection between p1 and p2),
        # repeated per sample with node ids shifted by 3 * sample index.
        temp_ei = torch.tensor([[0, 2, 1, 2], [2, 0, 2, 1]], dtype=torch.long)
        node_offsets = 3 * torch.arange(num_graphs, dtype=torch.long)
        nu_Batch.edge_index = (temp_ei.unsqueeze(1) +
                               node_offsets.view(1, -1, 1)).flatten(1)

        # Edge features from the first-level graph labels
        p1_div_pm = u_p1 / u_pm
        p2_div_pm = u_p2 / u_pm

        # Concatenate the temperature and molecular percentages
        concat_T_p1pm = torch.transpose(torch.stack([Temperature, p1_div_pm]),
                                        0, 1)
        concat_T_p2pm = torch.transpose(torch.stack([Temperature, p2_div_pm]),
                                        0, 1)

        # Encode the edge features
        concat_T_p1pm = self.encoding_edge_2(concat_T_p1pm)
        concat_T_p2pm = self.encoding_edge_2(concat_T_p2pm)

        # Each feature appears twice because the edges are bidirectional;
        # rows are ordered sample by sample, matching edge_index.
        nu_Batch.edge_attr = torch.stack(
            [concat_T_p1pm, concat_T_p1pm, concat_T_p2pm, concat_T_p2pm],
            dim=1).flatten(0, 1)

        # Nodes of sample b are rows 3b (p1), 3b + 1 (p2) and 3b + 2 (pm).
        nu_Batch.x = torch.stack([u1, u2, um], dim=1).flatten(0, 1)

        # Initial graph-level features for level two
        gamma = torch.full(size=(num_graphs, NO_GRAPH_FEATURES_TWO),
                           fill_value=0.1,
                           dtype=torch.float)
        nu_Batch.y = gamma

        # Node-to-sample assignment: three consecutive nodes per sample
        nu_Batch.batch = torch.arange(
            num_graphs, dtype=torch.long).repeat_interleave(3)

        nu_Batch = nu_Batch.to(device)

        # --- MESSAGE PASSING LVL 2
        for _ in range(self.no_mp_two):
            nu_Batch.x, nu_Batch.edge_attr, nu_Batch.y = self.meta3(
                x=nu_Batch.x,
                edge_index=nu_Batch.edge_index,
                edge_attr=nu_Batch.edge_attr,
                u=nu_Batch.y,
                batch=nu_Batch.batch)

        c = self.mlp_last(nu_Batch.y)

        return c
Exemplo n.º 14
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects that carry
        :obj:`depth_mask` and :obj:`depth_count` attributes.

        The assignment vector :obj:`batch` is created on the fly.
        Entries of :obj:`edge_index` selected by ``depth_mask == d`` are
        shifted by the running sum of ``depth_count[d - 1]`` over the
        preceding graphs; the per-depth counts of all graphs are summed into
        :obj:`batch.depth_count`.

        The input objects are not modified: the shift is applied to a copy of
        each :obj:`edge_index`.

        Entries for the keys in :obj:`follow_batch` are initialised as empty
        lists but are not populated by this implementation.

        Returns the value of ``batch.contiguous()``."""

        keys = [set(data.keys) for data in data_list]
        keys = set.union(*keys)
        keys.remove('depth_count')
        keys = list(keys)
        depth = max(data.depth_count.shape[0] for data in data_list)
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        cumsum = {i: 0 for i in range(depth)}
        depth_count = th.zeros((depth, ), dtype=th.long)
        batch.batch = []
        for i, data in enumerate(data_list):
            # Clone before shifting: the masked "+=" below writes in place and
            # would otherwise alter the caller's edge_index.
            edges = data['edge_index'].clone()
            for d in range(1, depth):
                mask = data.depth_mask == d
                edges[mask] += cumsum[d - 1]
                cumsum[d - 1] += data.depth_count[d - 1].item()
            batch['edge_index'].append(edges)
            depth_count += data['depth_count']

            for key in data.keys:
                # Both keys are handled separately above / below.
                if key == 'edge_index' or key == 'depth_count':
                    continue
                item = data[key]
                batch[key].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # Only the last element's num_nodes decides whether a batch vector
        # exists.
        if num_nodes is None:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])

        batch.depth_count = depth_count
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Exemplo n.º 15
0
def from_data_list(data_list, follow_batch=[]):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Every attribute is added to the running offset stored for its key
    (``0`` when none is stored yet). Offsets are only maintained for a fixed
    set of keys, each growing by a size or maximum read from the current
    graph; see ``offset_step`` below.

    Raises :class:`ValueError` for attributes that are neither tensors, ints
    nor floats."""

    def offset_step(key, data):
        """Return how much the offset of ``key`` grows after ``data``, or ``None``."""
        if key in ('edge_index', 'subgraphs_nodes'):
            return data['x_numeric'].shape[0]
        if key in ('subgraphs_edges', 'cycles_edge_index',
                   'edges_connectivity_ids'):
            return data['edge_attr_numeric'].shape[0]
        if key == 'subgraphs_connectivity':
            return data['subgraphs_nodes'].shape[0]
        if key in ('subgraphs_nodes_path_batch',
                   'subgraphs_edges_path_batch'):
            return data[key].max() + 1
        if key == 'subgraphs_global_path_batch':
            return 1
        # An empty cycles_id leaves the offset untouched.
        if key == 'cycles_id' and data['cycles_id'].shape[0] > 0:
            return data['cycles_id'].max() + 1
        return None

    keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert 'batch' not in keys

    batch = Batch()

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            # The offset accumulated from the previous graphs is applied
            # before it is advanced for the current one.
            item = data[key] + cumsum.get(key, 0)
            step = offset_step(key, data)
            if step is not None:
                cumsum[key] = cumsum[key] + step if key in cumsum else step
            batch[key].append(item)

        for key in follow_batch:
            size = data[key].size(data.__cat_dim__(key, data[key]))
            batch['{}_batch'.format(key)].append(
                torch.full((size, ), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))

    if num_nodes is None:
        batch.batch = None

    # These two keys are always concatenated along the last dimension.
    last_dim_keys = ('subgraphs_connectivity', 'edges_connectivity_ids')
    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if key in last_dim_keys:
                cat_dim = -1
            else:
                cat_dim = data_list[0].__cat_dim__(key, item)
            batch[key] = torch.cat(batch[key], dim=cat_dim)
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    # TODO: copying custom data functions from the data class onto the batch
    # is not implemented (an earlier attempt did not work).

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
Exemplo n.º 16
0
    def from_data_list(data_list, follow_batch=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Non-boolean tensors are shifted by the running sum of
        ``data.__inc__`` for their key, and cumulative sizes along
        ``__cat_dim__`` are recorded in ``batch.__slices__``. Attributes that
        are neither tensors, ints nor floats are left as collected lists."""
        keys = list(set.union(*[set(data.keys) for data in data_list]))
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}
        for key in keys:
            batch[key] = []
        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        offsets = dict.fromkeys(keys, 0)
        batch.batch = []
        for graph_id, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                is_tensor = torch.is_tensor(item)
                if is_tensor and item.dtype != torch.bool:
                    item = item + offsets[key]

                # Non-tensor attributes count as a single entry.
                size = item.size(data.__cat_dim__(key, data[key])) if is_tensor else 1
                slices = batch.__slices__[key]
                slices.append(size + slices[-1])

                offsets[key] = offsets[key] + data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    batch['{}_batch'.format(key)].append(
                        torch.full((size,), graph_id, dtype=torch.long))

            num_nodes = data.num_nodes
            if num_nodes is not None:
                batch.batch.append(
                    torch.full((num_nodes,), graph_id, dtype=torch.long))

        if num_nodes is None:
            batch.batch = None

        reference = data_list[0]  # its __cat_dim__ decides the concat axis
        for key in batch.keys:
            item = batch[key][0]
            logger.debug(f"key = {key}")
            if torch.is_tensor(item):
                logger.debug(f"batch[{key}]")
                logger.debug(f"item.shape = {item.shape}")
                cat_dim = reference.__cat_dim__(key, item)
                batch[key] = torch.cat(batch[key], dim=cat_dim)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])

        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()