def test_num_batches(self):
    """Check ``MockBaseDataset.get_num_samples`` for both data layouts."""
    # Dense layout: a ``pos`` tensor of shape (2, 3, 3) must report 2 samples.
    dense = Data(pos=torch.randn((2, 3, 3)))
    dense_count = MockBaseDataset.get_num_samples(
        dense, ConvolutionFormat.DENSE.value)
    self.assertEqual(dense_count, 2)

    # Partial-dense layout: three points, each tagged with its own batch
    # id, must report 3 samples.
    packed = Data(pos=torch.randn((3, 3)), batch=torch.tensor([0, 1, 2]))
    packed_count = MockBaseDataset.get_num_samples(
        packed, ConvolutionFormat.PARTIAL_DENSE.value)
    self.assertEqual(packed_count, 3)
def test_get_sample(self):
    """Check that ``MockBaseDataset.get_sample`` returns entry 1 of ``pos``."""
    # Dense layout: sample 1 is expected to equal ``pos[1]``.
    dense = Data(pos=torch.randn((2, 3, 3)))
    dense_sample = MockBaseDataset.get_sample(
        dense, "pos", 1, ConvolutionFormat.DENSE.value)
    torch.testing.assert_allclose(dense_sample, dense.pos[1])

    # Partial-dense layout with one point per batch id: sample 1 is
    # compared against ``pos[1]`` as well.
    packed = Data(pos=torch.randn((3, 3)), batch=torch.tensor([0, 1, 2]))
    packed_sample = MockBaseDataset.get_sample(
        packed, "pos", 1, ConvolutionFormat.PARTIAL_DENSE.value)
    torch.testing.assert_allclose(packed_sample, packed.pos[1])
def load_data(filename=DATA_PATH, directed=False):
    """Load a graph dataset stored as JSON into a ``Data`` object.

    The JSON document is read for the keys ``features``, ``labels``,
    ``links``, ``train_masks``, ``val_masks``, ``stopping_masks`` and
    ``test_mask``.

    Args:
        filename: Path of the JSON file to read.
        directed: If ``True``, every ``links[i]`` entry ``j`` yields only the
            edge ``(i, j)``. Otherwise both ``(i, j)`` and ``(j, i)`` are
            produced and duplicate edges are removed.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` and ``y`` set, plus the
        attributes ``train_masks``, ``val_masks``, ``stopping_masks`` (lists
        of mask tensors) and ``test_mask`` (a single mask tensor).
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filename) as handle:
        raw = json.load(handle)

    features = torch.FloatTensor(np.array(raw['features']))
    labels = torch.LongTensor(np.array(raw['labels']))

    # Older torch releases have no BoolTensor; fall back to ByteTensor there.
    if hasattr(torch, 'BoolTensor'):
        mask_type = torch.BoolTensor
    else:
        mask_type = torch.ByteTensor
    train_masks = [mask_type(tr) for tr in raw['train_masks']]
    val_masks = [mask_type(val) for val in raw['val_masks']]
    stopping_masks = [mask_type(st) for st in raw['stopping_masks']]
    test_mask = mask_type(raw['test_mask'])

    if directed:
        edges = [[(i, j) for j in js] for i, js in enumerate(raw['links'])]
        edges = list(itertools.chain(*edges))
    else:
        # Emit both directions, then de-duplicate through a set.
        edges = [[(i, j) for j in js] + [(j, i) for j in js]
                 for i, js in enumerate(raw['links'])]
        edges = list(set(itertools.chain(*edges)))

    # The explicit reshape keeps the result two-dimensional, i.e. (2, 0),
    # even when the graph has no edges at all.
    edge_index = torch.tensor(edges, dtype=torch.long)
    edge_index = edge_index.reshape(len(edges), 2).t().contiguous()

    data = Data(x=features, edge_index=edge_index, y=labels)
    data.train_masks = train_masks
    data.val_masks = val_masks
    data.stopping_masks = stopping_masks
    data.test_mask = test_mask
    return data
def forward(self, data):
    """Run the three set-abstraction stages followed by the linear head.

    ``data`` is either a ``torch.Tensor`` (converted below into a ``Data``
    object carrying ``pos`` and ``batch``) or an object that already exposes
    ``pos`` and ``batch`` (and optionally ``x``).

    Returns:
        The output of ``self.lin3``.
    """
    if isinstance(data, torch.Tensor):
        # Convert to torch_geometric.data.Data type
        points = data.transpose(1, 2).contiguous()
        batch_size, N, _ = points.shape  # (batch_size, num_points, 3)
        flat_pos = points.view(batch_size * N, -1)
        # One sample id per point: 0 repeated N times, then 1, and so on.
        sample_ids = torch.arange(
            batch_size, device=flat_pos.device,
            dtype=torch.long).repeat_interleave(N)
        data = Data()
        data.pos, data.batch = flat_pos, sample_ids

    if not hasattr(data, "x"):
        data.x = None

    stage_out = (data.x, data.pos, data.batch)
    for stage in (self.sa1_module, self.sa2_module, self.sa3_module):
        stage_out = stage(stage_out)
    x, pos, batch = stage_out

    # Two linear + ReLU + dropout layers, then the final linear layer.
    for layer in (self.lin1, self.lin2):
        x = F.relu(layer(x))
        x = F.dropout(x, p=self.config["dropout"], training=self.training)
    return self.lin3(x)
def __getitem__(self, idx):
    """Return a freshly sampled random point cloud; ``idx`` is not used.

    ``pos``, ``y`` and ``x`` are drawn (in that order) from a standard
    normal distribution with ``self.num_points`` rows and 3,
    ``self.output_nc`` and ``self.input_nc`` columns respectively.
    """

    def _gaussian(width):
        sample = np.random.normal(0, 1, (self.num_points, width))
        return torch.from_numpy(sample)

    # The draw order (pos, y, x) is kept so a seeded RNG yields the same data.
    pos = _gaussian(3)
    y = _gaussian(self.output_nc)
    x = _gaussian(self.input_nc)
    return Data(x=x, pos=pos, y=y)
def reorder_pyG(self, g):
    """Return a new ``Data`` whose edges follow ``self.reorder_edge_index``.

    The index returned by ``self.reorder_edge_index(g.edge_index)`` is used
    to select the columns of ``g.edge_index`` and the matching entries of
    ``g.edge_weight``. The result carries only these two attributes.
    """
    order = self.reorder_edge_index(g.edge_index)
    return Data(edge_index=g.edge_index[:, order],
                edge_weight=g.edge_weight[order])
def _read_file(self, filename):
    """Read a text array file and split its columns into a ``Data`` object.

    Columns 0-2 become ``pos`` and columns 3-5 become ``x``. When the array
    has exactly seven columns, column 6 is cast to ``torch.long`` and used as
    ``y``; otherwise ``y`` is ``None``.
    """
    table = read_txt_array(filename)
    has_labels = table.shape[1] == 7
    labels = table[:, 6].type(torch.long) if has_labels else None
    return Data(pos=table[:, :3], x=table[:, 3:6], y=labels)
def doc2graph(doc, word2idx=None):
    """Convert a Stanza document into a torch_geometric ``Data`` graph.

    Every sentence contributes a ``"[ROOT]"`` node followed by one node per
    token. Dependency relations become edges, and the root node of each
    sentence after the first is linked to the previous sentence's root by a
    ``"bridge"`` edge.

    Args:
        doc: A Stanza ``Document``, or a list that ``Document`` accepts
            (it is wrapped into a ``Document`` first).
        word2idx: Optional mapping from node text (including ``"[ROOT]"``)
            to an integer id. When ``None``, ``x`` holds running indices.

    Returns:
        ``Data`` with
            x: id tensor, one entry per node
            edge_index: integer tensor of shape (2, num_edges)
            edge_attr: list of ``(u, v, edge_type)`` tuples
            node_attr: list with the text of every node
    """
    if isinstance(doc, list):
        # Convert to a Document first if given in dict form ([[dict]]).
        doc = Document(doc)

    e = [[], []]
    edge_info = []
    node_info = []
    prev_token_sum = 0  # number of nodes emitted by earlier sentences
    prev_root_id = 0

    for sent in doc.sentences:
        # Node info by index: a root node at the start of every sentence.
        cur_root_id = len(node_info)
        node_info.append("[ROOT]")
        for token in sent.tokens:
            node_info.append(token.to_dict()[0]['text'])

        # Sentence-local word ids are shifted by the nodes emitted so far.
        for dep in sent.dependencies:
            id1 = prev_token_sum + int(dep[0].to_dict()["id"])
            id2 = prev_token_sum + int(dep[2].to_dict()["id"])
            e[0].append(id1)
            e[1].append(id2)
            edge_info.append((id1, id2, dep[1]))
        prev_token_sum += len(sent.tokens) + 1

        # Link this sentence's root to the previous one (not for the first).
        if cur_root_id != 0:
            e[0].append(prev_root_id)
            e[1].append(cur_root_id)
            edge_info.append((prev_root_id, cur_root_id, "bridge"))
        prev_root_id = cur_root_id

    if word2idx is None:
        # No vocabulary given: x is just a running index; the node text is
        # available through node_attr.
        x = torch.tensor(list(range(doc.num_tokens + len(doc.sentences))))
    else:
        x = torch.tensor([word2idx[token] for token in node_info])

    # An explicit dtype keeps edge_index integral even when there are no edges.
    e = torch.tensor(e, dtype=torch.long)
    G = Data(x=x, edge_index=e, edge_attr=edge_info, node_attr=node_info)
    return G
def get_planetoid(self, dataset='cora', root=None):
    """Load a Planetoid dataset and return its graph with unit edge weights.

    Args:
        dataset: Name of the Planetoid dataset. It is also used as the last
            component of the storage directory.
        root: Project directory that contains the ``data`` folder. Defaults
            to the original hard-coded location when ``None``.

    Returns:
        ``Data`` holding the dataset's ``edge_index`` and an ``edge_weight``
        vector of ones with one entry per edge.

    Raises:
        AssertionError: If the resulting graph is directed.
    """
    if root is None:
        root = '/home/cai.507/Documents/DeepLearning/sparsifier/sparsenet'
    path = osp.join(root, 'data', dataset)

    # Pass the transform by keyword: the third positional parameter of
    # ``Planetoid`` is not ``transform`` in every torch_geometric release.
    planetoid = Planetoid(path, dataset, transform=T.TargetIndegree())

    edge_index = planetoid.data.edge_index
    n_edge = edge_index.size(1)
    g = Data(edge_index=edge_index, edge_weight=torch.ones(n_edge))
    assert not g.is_directed()
    return g
def make_data_batch():
    """Build a random two-sample batch with 1000 and 1024 points.

    Returns:
        ``Data`` with ``pos`` of shape (2024, 3) drawn from ``torch.rand`` and
        a ``batch`` vector holding 0 for the first 1000 points and 1 for the
        remaining 1024.
    """
    points_per_sample = (1000, 1024)  # batch_size = 2

    data_batch = Data()
    data_batch.pos = torch.cat(
        [torch.rand(count, 3) for count in points_per_sample], dim=0)
    data_batch.batch = torch.cat([
        torch.full((count,), sample_id, dtype=torch.long)
        for sample_id, count in enumerate(points_per_sample)
    ])
    return data_batch
def rm_pyG_edges(self, g, n=1):
    """Return a copy of ``g`` with ``n`` randomly chosen edge pairs removed.

    Columns ``2k`` and ``2k + 1`` of ``g.edge_index`` are handled as one
    pair, so both are kept or dropped together.
    # NOTE(review): presumably each pair holds the two directions of one
    # undirected edge; confirm against the code that builds ``g``.
    """
    n_edge = g.edge_index.size(1) // 2
    kept_pairs = random.sample(range(n_edge), k=n_edge - n)
    columns = [
        column for pair in kept_pairs for column in (2 * pair, 2 * pair + 1)
    ]
    return Data(edge_index=g.edge_index[:, columns],
                edge_weight=g.edge_weight[columns])
def test_add_weights(self):
    """Check that the rarer class gets the larger weight.

    The labels are ``[1, 1, 1, 0]``, so class 0 is the minority class and its
    weight must exceed that of class 1, both with the default method and with
    ``class_weight_method="log"``.
    """
    dataset_opt = MockDatasetConfig()
    dataset_opt.dataroot = os.path.join(DIR, "temp_dataset")

    dataset = MockBaseDataset(dataset_opt)
    dataset.train_dataset.data = Data(y=torch.tensor([1, 1, 1, 0]))

    dataset.add_weights()
    default_weights = dataset.weight_classes
    self.assertGreater(default_weights[0], default_weights[1])

    dataset.add_weights(class_weight_method="log")
    print(dataset.weight_classes)
    log_weights = dataset.weight_classes
    self.assertGreater(log_weights[0], log_weights[1])
def forward(self, data):
    """Compute per-point log-probabilities for a batch of point clouds.

    data: a batch of input, torch.Tensor or torch_geometric.data.Data type
        - torch.Tensor: (batch_size, 3, num_points), as common batch input
        - torch_geometric.data.Data, as torch_geometric batch input:
            data.x: (batch_size * ~num_points, C), batch nodes/points
                feature, ~num_points means each sample can have a different
                number of points/nodes
            data.pos: (batch_size * ~num_points, 3)
            data.batch: (batch_size * ~num_points,), a column vector of
                graph/pointcloud identifiers for all nodes of all
                graphs/pointclouds in the batch. See the pytorch_geometric
                documentation for more information

    Returns:
        For tensor input, the log-probabilities reshaped to
        (batch_size, num_points, self.num_classes). Otherwise a tuple of the
        log-probabilities and the batch vector coming out of the last
        feature-propagation module.
    """
    dense_input = isinstance(data, torch.Tensor)

    if dense_input:
        # Convert to torch_geometric.data.Data type
        points = data.transpose(1, 2).contiguous()
        batch_size, N, _ = points.shape  # (batch_size, num_points, 3)
        flat_pos = points.view(batch_size * N, -1)
        # One sample id per point: 0 repeated N times, then 1, and so on.
        sample_ids = torch.arange(
            batch_size, device=flat_pos.device,
            dtype=torch.long).repeat_interleave(N)
        data = Data()
        data.pos, data.batch = flat_pos, sample_ids

    if not hasattr(data, "x"):
        data.x = None
    data_in = data.x, data.pos, data.batch

    # Set-abstraction path.
    sa1_out = self.sa1_module(data_in)
    sa2_out = self.sa2_module(sa1_out)
    sa3_out = self.sa3_module(sa2_out)

    # Feature-propagation path; each stage also receives an earlier output.
    fp3_out = self.fp3_module(sa3_out, sa2_out)
    fp2_out = self.fp2_module(fp3_out, sa1_out)
    fp1_out = self.fp1_module(fp2_out, data_in)
    fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out

    hidden = self.fc1(fp1_out_x)
    hidden = self.dropout1(hidden)
    x = F.log_softmax(self.fc2(hidden), dim=-1)

    if not dense_input:
        return x, fp1_out_batch
    return x.view(batch_size, N, self.num_classes)
def increase_random_edge_w(self, g, n=1, w=1000):
    """Set the weight of ``n`` randomly chosen edge pairs to ``w``.

    Works on a deep copy of ``g``, so the input graph is left untouched.
    Entries ``2k`` and ``2k + 1`` of ``edge_weight`` are handled as one pair
    and both receive the new weight.

    Returns:
        A new ``Data`` with the copied ``edge_index`` and the modified
        ``edge_weight``.
    """
    g = copy.deepcopy(g)
    n_edge = g.edge_index.size(1) // 2
    assert n < n_edge, f'n {n} has to be smaler than {n_edge}.'

    weights = g.edge_weight
    for pair in random.sample(range(n_edge), k=n):
        for position in (2 * pair, 2 * pair + 1):
            weights[position] = w
    return Data(edge_index=g.edge_index, edge_weight=weights)
def doc2graph_allennlp(doc: Dict) -> Data:
    """Convert an AllenNLP dependency parse into a torch_geometric graph.

    Node 0 is an artificial ``"[ROOT]"`` node and word ``i`` (1-based) is
    node ``i``. Every word gets one edge ``(i, predicted_heads[i - 1])``.

    Args:
        doc: Parser output; the keys ``"words"``, ``"predicted_heads"`` and
            ``"predicted_dependencies"`` are read.

    Returns:
        ``Data`` with
            x: id tensor with one entry per node (ROOT included)
            edge_index: integer tensor of shape (2, number of words)
            edge_attr: list of ``(u, v, edge_type)`` tuples
            node_attr: list of node texts, starting with ``"[ROOT]"``
    """
    n = len(doc["words"])
    e = [list(range(1, n + 1)), doc["predicted_heads"]]
    edge_attr = list(zip(e[0], e[1], doc["predicted_dependencies"]))

    # Add the root token in front of the words.
    node_attr = ["[ROOT]"]
    node_attr.extend(doc["words"])

    # The graph has n + 1 nodes (ROOT plus n words) and the edges above
    # reference node n, so x needs n + 1 entries rather than n.
    x = torch.arange(n + 1)
    # An explicit dtype keeps edge_index integral even for an empty sentence.
    e = torch.tensor(e, dtype=torch.long)
    G = Data(x=x, edge_index=e, edge_attr=edge_attr, node_attr=node_attr)
    return G
def to_homogeneous(self, node_attrs: Optional[List[str]] = None,
                   edge_attrs: Optional[List[str]] = None,
                   add_node_type: bool = True,
                   add_edge_type: bool = True) -> Data:
    """Converts a :class:`~torch_geometric.data.HeteroData` object to a
    homogeneous :class:`~torch_geometric.data.Data` object.
    By default, all features with same feature dimensionality across
    different types will be merged into a single representation, unless
    otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs`
    arguments.
    Furthermore, attributes named :obj:`node_type` and :obj:`edge_type`
    will be added to the returned :class:`~torch_geometric.data.Data`
    object, denoting node-level and edge-level vectors holding the
    node and edge type as integers, respectively.

    Args:
        node_attrs (List[str], optional): The node features to combine
            across all node types. These node features need to be of the
            same feature dimensionality. If set to :obj:`None`, will
            automatically determine which node features to combine.
            (default: :obj:`None`)
        edge_attrs (List[str], optional): The edge features to combine
            across all edge types. These edge features need to be of the
            same feature dimensionality. If set to :obj:`None`, will
            automatically determine which edge features to combine.
            (default: :obj:`None`)
        add_node_type (bool, optional): If set to :obj:`False`, will not
            add the node-level vector :obj:`node_type` to the returned
            :class:`~torch_geometric.data.Data` object.
            (default: :obj:`True`)
        add_edge_type (bool, optional): If set to :obj:`False`, will not
            add the edge-level vector :obj:`edge_type` to the returned
            :class:`~torch_geometric.data.Data` object.
            (default: :obj:`True`)

    Returns:
        Data: The merged homogeneous graph. Besides the merged attributes it
        carries :obj:`_node_type_names` and :obj:`_edge_type_names`, the
        lists of type names in the order used for the integer type ids.
    """
    def _consistent_size(stores: List[BaseStorage]) -> List[str]:
        # Return the keys that hold a tensor in *every* store and whose
        # shapes agree once the concatenation dimension is removed.
        sizes_dict = defaultdict(list)
        for store in stores:
            for key, value in store.items():
                # Connectivity is merged separately (with index offsets).
                if key in ['edge_index', 'adj_t']:
                    continue
                if isinstance(value, Tensor):
                    dim = self.__cat_dim__(key, value, store)
                    # Shape without the dimension that will be concatenated.
                    size = value.size()[:dim] + value.size()[dim + 1:]
                    sizes_dict[key].append(tuple(size))
        return [
            k for k, sizes in sizes_dict.items()
            if len(sizes) == len(stores) and len(set(sizes)) == 1
        ]

    # Start from the graph-level (global) attributes.
    data = Data(**self._global_store.to_dict())

    # Iterate over all node stores and record the slice information:
    node_slices, cumsum = {}, 0
    node_type_names, node_types = [], []
    for i, (node_type, store) in enumerate(self._node_store_dict.items()):
        num_nodes = store.num_nodes
        # Half-open range [start, end) of this type in the merged graph.
        node_slices[node_type] = (cumsum, cumsum + num_nodes)
        node_type_names.append(node_type)
        cumsum += num_nodes

        if add_node_type:
            # NOTE(review): no device is passed here, so these vectors are
            # created on torch's default device - confirm this is intended
            # when the stores live on another device.
            kwargs = {'dtype': torch.long}
            node_types.append(torch.full((num_nodes, ), i, **kwargs))
    data._node_type_names = node_type_names

    # `node_types` stays empty when `add_node_type` is False, in which
    # case no `node_type` attribute is written.
    if len(node_types) > 1:
        data.node_type = torch.cat(node_types, dim=0)
    elif len(node_types) == 1:
        data.node_type = node_types[0]

    # Combine node attributes into a single tensor:
    if node_attrs is None:
        node_attrs = _consistent_size(self.node_stores)
    for key in node_attrs:
        values = [store[key] for store in self.node_stores]
        dim = self.__cat_dim__(key, values[0], self.node_stores[0])
        value = torch.cat(values, dim) if len(values) > 1 else values[0]
        data[key] = value

    # Store the node count explicitly when none of the merged keys is
    # 'x', 'pos', 'batch' or contains 'node', and no `node_type` was added.
    # NOTE(review): presumably these are the attributes `Data` can infer
    # the number of nodes from - confirm against `Data.num_nodes`.
    if len([
            key for key in node_attrs
            if (key in {'x', 'pos', 'batch'} or 'node' in key)
    ]) == 0 and not add_node_type:
        data.num_nodes = cumsum

    # Iterate over all edge stores and record the slice information:
    # (`cumsum` is reset and now counts edges instead of nodes.)
    edge_slices, cumsum = {}, 0
    edge_indices, edge_type_names, edge_types = [], [], []
    for i, (edge_type, store) in enumerate(self._edge_store_dict.items()):
        src, _, dst = edge_type
        num_edges = store.num_edges
        edge_slices[edge_type] = (cumsum, cumsum + num_edges)
        edge_type_names.append(edge_type)
        cumsum += num_edges

        kwargs = {'dtype': torch.long, 'device': store.edge_index.device}
        # Shift source and target ids by the start of their node-type
        # slice; the (2, 1) offset broadcasts over all edges.
        offset = [[node_slices[src][0]], [node_slices[dst][0]]]
        offset = torch.tensor(offset, **kwargs)
        edge_indices.append(store.edge_index + offset)
        if add_edge_type:
            edge_types.append(torch.full((num_edges, ), i, **kwargs))
    data._edge_type_names = edge_type_names

    # With no edge types at all, `edge_index` is left unset.
    if len(edge_indices) > 1:
        data.edge_index = torch.cat(edge_indices, dim=-1)
    elif len(edge_indices) == 1:
        data.edge_index = edge_indices[0]

    if len(edge_types) > 1:
        data.edge_type = torch.cat(edge_types, dim=0)
    elif len(edge_types) == 1:
        data.edge_type = edge_types[0]

    # Combine edge attributes into a single tensor:
    if edge_attrs is None:
        edge_attrs = _consistent_size(self.edge_stores)
    for key in edge_attrs:
        values = [store[key] for store in self.edge_stores]
        dim = self.__cat_dim__(key, values[0], self.edge_stores[0])
        value = torch.cat(values, dim) if len(values) > 1 else values[0]
        data[key] = value

    return data
# Pad every feature matrix to `max_nodes` rows, keeping its column count.
all_features_pad = [
    pad(feat, (max_nodes, feat.shape[1])) for feat in all_link_features
]


def create_mask(feat, max_nodes):
    """Return a boolean array of length ``max_nodes``.

    Entry ``i`` is ``True`` for ``i < feat.shape[0]`` and ``False`` after
    that, i.e. it marks the rows that exist in the unpadded ``feat``.
    """
    return np.array(
        [True if i < feat.shape[0] else False for i in range(max_nodes)])


all_masks = [create_mask(feat, max_nodes) for feat in all_link_features]
num_features = all_features_pad[0].shape[1]
print('max-nodes = ', max_nodes, ', num-features = ', num_features)

# Step 3: Create dataset object
# One Data object per design: adjacency, node mask, features and reward.
data = [
    Data(adj=torch.from_numpy(adj).float(),
         mask=torch.from_numpy(mask),
         x=torch.from_numpy(x[:, :num_features]).float(),
         y=torch.from_numpy(np.array([y])).float())
    for adj, mask, x, y in zip(all_link_adj_symmetric_pad, all_masks,
                               all_features_pad, all_rewards)
]

random.shuffle(data)
# Hold out ceil(len(data) / 3) samples as the "unknown" (test) part.
n_test = (len(data) + 2) // 3
# NOTE(review): the list is shuffled a second time here; it looks redundant
# but changes which samples land in each split, so it is kept as is.
random.shuffle(data)
known_dataset = data[:-n_test]
unknown_dataset = data[-n_test:]
# random.shuffle(known_dataset)
# Of the known part, ceil(len / 10) samples are used for validation.
n_val = (len(known_dataset) + 9) // 10
train_dataset = known_dataset[:-n_val]
val_dataset = known_dataset[-n_val:]
test_dataset = unknown_dataset
all_features_pad = [ pad(feat, (max_nodes, feat.shape[1])) for feat in all_link_features ] def create_mask(feat, max_nodes): return np.array( [True if i < feat.shape[0] else False for i in range(max_nodes)]) all_masks = [create_mask(feat, max_nodes) for feat in all_link_features] #num_channels = all_features_pad[0].shape[1] #step 3: Create dataset object data = [ Data(adj=torch.from_numpy(adj).float(), mask=torch.from_numpy(mask), x=torch.from_numpy(x[:, :num_channels]).float(), y=torch.from_numpy(np.array([y])).float(), std=torch.from_numpy(np.array([std_dict[std]])).float()) for adj, mask, x, y, std in zip(all_link_adj_symmetric_pad, all_masks, all_features_pad, all_rewards, std_dict) ] import random random.shuffle(data) dataset = dataset.shuffle() n = (len(dataset) + 9) // 10 test_dataset = data[:n] val_dataset = data[n:2 * n] train_dataset = data[2 * n:] with open('test_loader',
test_dataset = pickle.load(test_file) train_dataset = pickle.load(train_file) val_dataset = pickle.load(val_file) else: os.makedirs(dataset_dir, exist_ok=True) raw_dataset_path = os.path.join(current_dir, 'data', args.dataset_name + '.csv') all_features, all_link_adj, all_masks, all_rewards \ = load_terminal_design_data(raw_dataset_path, os.path.join(project_dir, 'data/designs/grammar_jan21.dot')) # Create dataset object data = [ Data(adj=torch.from_numpy(adj).float(), mask=torch.from_numpy(mask), x=torch.from_numpy(x).float(), y=torch.from_numpy(np.array([y])).float()) for adj, mask, x, y in zip(all_link_adj, all_masks, all_features, all_rewards) ] random.shuffle(data) n_val = (len(data) + 9) // 10 n_test = (len(data) + 9) // 10 train_dataset = data[:-n_test - n_val] val_dataset = data[-n_test - n_val:-n_test] test_dataset = data[-n_test:] with open(testset_path, 'wb') as test_file, open( valset_path, 'wb') as val_file, open(trainset_path, 'wb') as train_file: pickle.dump(test_dataset, test_file)
def load_one_graph(self, fname, mol):
    """Loads one graph

    Args:
        fname (str): hdf5 file name
        mol (str): name of the molecule

    Returns:
        Data object or None: torch_geometric Data object containing the node
        features, the internal and external edge features, the target and
        the xyz coordinates. Return None if the molecule group, the node
        features or the edge data cannot be loaded.

    Raises:
        Exception: Errors raised while reading the target score or the node
        positions are not caught and propagate to the caller (the file is
        still closed).
    """

    def _stack_features(grp, prefix, names):
        # Read grp[prefix + name] for every name and stack them column-wise;
        # 1-D datasets are turned into single-column matrices first.
        columns = []
        for feat in names:
            vals = grp[prefix + feat][()]
            if vals.ndim == 1:
                vals = vals.reshape(-1, 1)
            columns.append(vals)
        return np.hstack(columns)

    def _load_edges(grp, index_key, data_prefix):
        # Load one edge set as (edge_index, edge_attr). We have to have all
        # the edges, i.e. (i, j) and (j, i), so the stored index is extended
        # with its flipped copy and the features are duplicated to match.
        ind = grp[index_key][()]
        ind = np.vstack((ind, np.flip(ind, 1))).T
        index = torch.tensor(ind, dtype=torch.long).contiguous()
        if self.edge_feature is None:
            return index, None
        feats = _stack_features(grp, data_prefix, self.edge_feature)
        feats = np.vstack((feats, feats))
        feats = self.edge_feature_transform(feats)
        return index, torch.tensor(feats, dtype=torch.float).contiguous()

    # The context manager closes the file on every exit path, including
    # exceptions raised while reading the target or the positions.
    with h5py.File(fname, 'r') as f5:
        try:
            grp = f5[mol]
        except Exception:
            # Best effort: a missing molecule simply yields no graph.
            return None

        # nodes
        try:
            x = torch.tensor(
                _stack_features(grp, 'node_data/', self.node_feature),
                dtype=torch.float)
        except Exception:
            # Report the file actually being read.
            print('node attributes not found in the file', fname)
            return None

        # external and internal edges
        try:
            edge_index, edge_attr = _load_edges(
                grp, 'edge_index', 'edge_data/')
            internal_edge_index, internal_edge_attr = _load_edges(
                grp, 'internal_edge_index', 'internal_edge_data/')
        except Exception:
            print('edge features not found in the file', fname)
            return None

        # target
        y = None
        if self.target is not None:
            score = grp['score/' + self.target][()]
            if score is not None:
                y = torch.tensor([score], dtype=torch.float).contiguous()

        # pos
        pos = torch.tensor(grp['node_data/pos/'][()],
                           dtype=torch.float).contiguous()

        # load
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    y=y, pos=pos)
        data.internal_edge_index = internal_edge_index
        data.internal_edge_attr = internal_edge_attr

        # mol name
        data.mol = mol

        # cluster: both depth levels must exist for the selected method.
        cluster_grp = None
        if ('clustering' in grp.keys()
                and self.clustering_method in grp['clustering'].keys()):
            cluster_grp = grp['clustering/' + self.clustering_method]

        if (cluster_grp is not None
                and 'depth_0' in cluster_grp.keys()
                and 'depth_1' in cluster_grp.keys()):
            data.cluster0 = torch.tensor(cluster_grp['depth_0'][()],
                                         dtype=torch.long)
            data.cluster1 = torch.tensor(cluster_grp['depth_1'][()],
                                         dtype=torch.long)
        else:
            print('WARNING: no cluster detected')

        return data
def to_homogeneous(self, node_attrs: Optional[List[str]] = None,
                   edge_attrs: Optional[List[str]] = None,
                   add_node_type: bool = True,
                   add_edge_type: bool = True) -> Data:
    """Converts a :class:`~torch_geometric.data.HeteroData` object to a
    homogeneous :class:`~torch_geometric.data.Data` object.
    Features that share the same dimensionality across all types are merged
    into a single representation by default; pass :obj:`node_attrs` and
    :obj:`edge_attrs` to select the merged features explicitly.
    In addition, the integer vectors :obj:`node_type` and :obj:`edge_type`
    are attached to the returned :class:`~torch_geometric.data.Data` object
    to identify the type of every node and edge.

    Args:
        node_attrs (List[str], optional): The node features to combine
            across all node types. These node features need to be of the
            same feature dimensionality. If set to :obj:`None`, will
            automatically determine which node features to combine.
            (default: :obj:`None`)
        edge_attrs (List[str], optional): The edge features to combine
            across all edge types. These edge features need to be of the
            same feature dimensionality. If set to :obj:`None`, will
            automatically determine which edge features to combine.
            (default: :obj:`None`)
        add_node_type (bool, optional): If set to :obj:`False`, will not
            add the node-level vector :obj:`node_type` to the returned
            :class:`~torch_geometric.data.Data` object.
            (default: :obj:`True`)
        add_edge_type (bool, optional): If set to :obj:`False`, will not
            add the edge-level vector :obj:`edge_type` to the returned
            :class:`~torch_geometric.data.Data` object.
            (default: :obj:`True`)
    """
    def _consistent_size(stores: List[BaseStorage]) -> List[str]:
        # Keys that hold a tensor in every store and whose shapes agree
        # once the concatenation dimension is left out.
        shapes_by_key = defaultdict(list)
        for store in stores:
            for key, value in store.items():
                if key in ['edge_index', 'adj_t']:
                    continue
                if not isinstance(value, Tensor):
                    continue
                dim = self.__cat_dim__(key, value, store)
                full_shape = value.size()
                reduced = full_shape[:dim] + full_shape[dim + 1:]
                shapes_by_key[key].append(tuple(reduced))
        num_stores = len(stores)
        return [
            key for key, shapes in shapes_by_key.items()
            if len(shapes) == num_stores and len(set(shapes)) == 1
        ]

    edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self)
    device = None if edge_index is None else edge_index.device

    data = Data(**self._global_store.to_dict())
    if edge_index is not None:
        data.edge_index = edge_index
    data._node_type_names = list(node_slices.keys())
    data._edge_type_names = list(edge_slices.keys())

    def _merge_into_data(stores, keys):
        # Concatenate every listed attribute over all stores.
        for key in keys:
            parts = [store[key] for store in stores]
            dim = self.__cat_dim__(key, parts[0], stores[0])
            data[key] = torch.cat(parts, dim) if len(parts) > 1 else parts[0]

    def _type_vector(slices):
        # Type id i repeated once per element of the i-th slice.
        counts = [bounds[1] - bounds[0] for bounds in slices.values()]
        counts = torch.tensor(counts, dtype=torch.long, device=device)
        type_ids = torch.arange(len(counts), device=device)
        return type_ids.repeat_interleave(counts)

    # Combine node attributes into a single tensor:
    if node_attrs is None:
        node_attrs = _consistent_size(self.node_stores)
    _merge_into_data(self.node_stores, node_attrs)

    # Fall back to the end of the last node slice as the node count.
    if not data.can_infer_num_nodes:
        data.num_nodes = list(node_slices.values())[-1][1]

    # Combine edge attributes into a single tensor:
    if edge_attrs is None:
        edge_attrs = _consistent_size(self.edge_stores)
    _merge_into_data(self.edge_stores, edge_attrs)

    if add_node_type:
        data.node_type = _type_vector(node_slices)

    if add_edge_type and edge_index is not None:
        data.edge_type = _type_vector(edge_slices)

    return data