def test_batching_of_batches():
    data = Data(x=torch.randn(2, 16))
    batch = Batch.from_data_list([data, data])

    # Batching two batches flattens them into a single batch of four graphs;
    # `len` here counts the attribute keys ('x' and 'batch'), not graphs.
    batch = Batch.from_data_list([batch, batch])
    assert len(batch) == 2
    assert batch.x[0:2].tolist() == data.x.tolist()
    assert batch.x[2:4].tolist() == data.x.tolist()
    assert batch.x[4:6].tolist() == data.x.tolist()
    assert batch.x[6:8].tolist() == data.x.tolist()
    assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]

def from_data_list(data_list):
    r"""Create a batch from a list of
    torch_points3d.datasets.registration.pair.Pair objects.

    Warning: follow_batch is not supported yet.
    """
    assert isinstance(data_list[0], Pair)
    data_list_s, data_list_t = list(
        map(list, zip(*[data.to_data() for data in data_list])))
    batch_s = Batch.from_data_list(data_list_s)
    batch_t = Batch.from_data_list(data_list_t)
    return PairBatch.make_pair(batch_s, batch_t).contiguous()

def sample(self, batch_size):
    max_mem = min(self.mem_cntr, self.mem_size)
    batch = np.random.choice(max_mem, batch_size, replace=False)

    graphs_pre_batch = Batch.from_data_list(
        [self.graphs_pre[b] for b in batch])
    graphs_later_batch = Batch.from_data_list(
        [self.graphs_later[b] for b in batch])
    actions_batch = T.tensor([self.actions[b] for b in batch])
    rewards_batch = T.tensor([self.rewards[b] for b in batch])

    return graphs_pre_batch, graphs_later_batch, actions_batch, rewards_batch

def score(self):
    """Score the model on all (test graph, training graph) pairs."""
    print("\n\nModel evaluation.\n")
    self.model.eval()

    scores = np.empty((len(self.testing_graphs), len(self.training_graphs)))
    ground_truth = np.empty(
        (len(self.testing_graphs), len(self.training_graphs)))
    prediction_mat = np.empty(
        (len(self.testing_graphs), len(self.training_graphs)))

    rho_list = []
    tau_list = []
    prec_at_10_list = []
    prec_at_20_list = []

    t = tqdm(total=len(self.testing_graphs) * len(self.training_graphs))
    for i, g in enumerate(self.testing_graphs):
        # Pair the i-th test graph with every training graph.
        source_batch = Batch.from_data_list([g] * len(self.training_graphs))
        target_batch = Batch.from_data_list(self.training_graphs)

        data = self.transform((source_batch, target_batch))
        target = data["target"]
        ground_truth[i] = target

        prediction = self.model(data)
        prediction_mat[i] = prediction.detach().numpy()
        scores[i] = F.mse_loss(prediction, target,
                               reduction="none").detach().numpy()

        rho_list.append(calculate_ranking_correlation(
            spearmanr, prediction_mat[i], ground_truth[i]))
        tau_list.append(calculate_ranking_correlation(
            kendalltau, prediction_mat[i], ground_truth[i]))
        prec_at_10_list.append(
            calculate_prec_at_k(10, prediction_mat[i], ground_truth[i]))
        prec_at_20_list.append(
            calculate_prec_at_k(20, prediction_mat[i], ground_truth[i]))

        t.update(len(self.training_graphs))

    self.rho = np.mean(rho_list).item()
    self.tau = np.mean(tau_list).item()
    self.prec_at_10 = np.mean(prec_at_10_list).item()
    self.prec_at_20 = np.mean(prec_at_20_list).item()
    self.model_error = np.mean(scores).item()
    self.print_evaluation()

def set_input(self, data, device):
    self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)
    if hasattr(data, "pos_target"):
        self.input_target = Batch(pos=data.pos_target,
                                  x=data.x_target,
                                  batch=data.batch_target).to(device)
    self.match = data.pair_ind.to(torch.long).to(device)
    self.size_match = data.size_pair_ind.to(torch.long).to(device)

def keep_human_object_interactions(input_graph: Batch, target_graph: Batch,
                                   filename: str, *, human_class: int):
    # Keep only relations whose subject is a human.
    subjs = input_graph.object_classes[input_graph.relation_indexes[0]]
    keep = subjs == human_class

    input_graph.n_edges = keep.sum().item()
    input_graph.relation_indexes = input_graph.relation_indexes[:, keep]
    input_graph.relation_linear_features = (
        input_graph.relation_linear_features[keep])

    return input_graph, target_graph

def collate(data_list):
    batch = Batch()
    batch.batch = []
    for key in data_list[0].keys:
        batch[key] = default_collate([d[key] for d in data_list])
    for i, data in enumerate(data_list):
        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes, ), i, dtype=torch.long)
            batch.batch.append(item)
    batch.batch = torch.cat(batch.batch, dim=0)
    return batch

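# For contrast with the default_collate-based variant above, a minimal sketch
# of the stock Batch.from_data_list semantics most snippets in this section
# rely on: the resulting `batch` vector maps each node to its graph index,
# which global pooling layers consume directly. The graphs are illustrative.
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool

graphs = [Data(x=torch.randn(3, 4)), Data(x=torch.randn(2, 4))]
batch = Batch.from_data_list(graphs)
assert batch.batch.tolist() == [0, 0, 0, 1, 1]   # node -> graph index
pooled = global_mean_pool(batch.x, batch.batch)  # shape [2, 4]
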
def collate_fn_withpad(data_list):
    """Collate function that zero-pads 3-D attributes to a common length.

    Modified from PyTorch Geometric's implementation.

    :param data_list: list of graph data objects to batch
    :return: a contiguous Batch
    """
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()
    for key in keys:
        batch[key] = []
    batch.batch = []

    cumsum = 0
    for i, data in enumerate(data_list):
        num_nodes = data.num_nodes
        batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
        for key in data.keys:
            item = data[key]
            # Offset index-valued attributes by the nodes seen so far.
            item = item + cumsum if data.__cumsum__(key, item) else item
            batch[key].append(item)
        cumsum += num_nodes

    for key in keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if len(item.shape) == 3:
                # Zero-pad 3-D tensors to the maximum length along dim 1
                # before concatenating, and record the original lengths.
                tlens = [x.shape[1] for x in batch[key]]
                maxtlens = np.max(tlens)
                to_cat = []
                for x in batch[key]:
                    to_cat.append(torch.cat([
                        x,
                        x.new_zeros(x.shape[0], maxtlens - x.shape[1],
                                    x.shape[2])
                    ], dim=1))
                batch[key] = torch.cat(to_cat, dim=0)
                if 'tlens' not in batch.keys:
                    batch['tlens'] = item.new_tensor(tlens, dtype=torch.long)
            else:
                batch[key] = torch.cat(
                    batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')

    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()

def forward(self, data, **kwargs):
    batch_obj = Batch()
    x, pos, batch = data.x, data.pos, data.batch
    if pos is not None:
        x = self.nn(torch.cat([x, pos], dim=1))
    x = self.pool(x, batch)
    batch_obj.x = x
    if pos is not None:
        batch_obj.pos = pos.new_zeros((x.size(0), 3))
    batch_obj.batch = torch.arange(x.size(0), device=batch.device)
    copy_from_to(data, batch_obj)
    return batch_obj

def sample(self, batch_size):
    max_mem = min(self.mem_cntr, self.mem_size)
    # Sample transitions with probability proportional to their stored
    # reward (assumes non-negative rewards). Only the filled part of the
    # buffer is eligible, hence `max_mem` rather than `self.mem_size`.
    p = np.array(self.rewards[:max_mem], dtype=np.float64)
    p = p / p.sum()
    batch = np.random.choice(max_mem, batch_size, replace=False, p=p)

    graphs_former_batch = Batch.from_data_list(
        [self.graphs_former[b] for b in batch])
    graphs_later_batch = Batch.from_data_list(
        [self.graphs_later[b] for b in batch])
    actions_batch = torch.Tensor([self.actions[b] for b in batch])
    rewards_batch = torch.Tensor([self.rewards[b] for b in batch])
    done_batch = torch.Tensor([self.done[b] for b in batch])

    return (graphs_former_batch, graphs_later_batch, actions_batch,
            rewards_batch, done_batch)

def forward(self, batch):
    r"""Compute the raw edge scores and normalize them per source node."""
    data_list = Batch.to_data_list(batch)
    data_list_out = []
    for data in data_list:
        # Softmax over the edge attributes, grouped by source node.
        new_edge_attr = softmax(data.edge_attr, data.edge_index[0])
        data.edge_attr = new_edge_attr
        data_list_out.append(data)
    batch = Batch.from_data_list(data_list_out)
    return batch

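# Since torch_geometric.utils.softmax already groups rows by an index vector,
# the to_data_list round trip above is avoidable: in a batched graph the
# source indices in edge_index[0] are offset per graph, so the groups are
# identical. A minimal equivalent sketch (the function name is illustrative):
from torch_geometric.utils import softmax

def normalize_edge_scores(batch):
    # Normalize raw edge scores per source node directly on the batch.
    batch.edge_attr = softmax(batch.edge_attr, batch.edge_index[0])
    return batch
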
def split_vr_batch(relations: Batch) -> List[Data]:
    # Hack to force torch_geometric to accept our graphs
    relations.x = relations.object_boxes
    relations.__slices__["x"] = relations.__slices__["object_boxes"]

    result = []
    for r in relations.to_data_list():
        r.x = None
        r.n_nodes = r.n_nodes.item()
        r.n_edges = r.n_edges.item()
        result.append(r)

    # Undo the aliasing on the input batch.
    relations.x = None
    del relations.__slices__["x"]
    return result

def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, images in enumerate(train_loader):
        for im in images:
            im.edge_attr = None
        images_cls = Batch.from_data_list(images)
        # Two independent random augmentations: query and key views.
        im_q = Batch.from_data_list(random_augmentation(images))
        im_k = Batch.from_data_list(random_augmentation(images))

        # measure data loading/augmentation time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            im_q = im_q.to(args.gpu)  # , non_blocking=True)
            im_k = im_k.to(args.gpu)  # , non_blocking=True)
            images_cls = images_cls.to(args.gpu)

        output, target, q_cls = model(im_q=im_q, im_k=im_k, image=images_cls)
        if args.gpu is not None:
            target = target.to(args.gpu)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), len(images))
        loss_list.append(loss.item())  # presumably a module-level list
        top1.update(acc1[0], len(images))
        top5.update(acc5[0], len(images))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

def make_bce_and_rank_targets(input_graph: Batch, target_graph: Batch,
                              filename: str, *, num_classes):
    """Binary and rank encoding of the unique target predicates."""
    unique_predicates = torch.unique(target_graph.predicate_classes,
                                     sorted=False)

    # Multi-hot vector over all predicate classes.
    target_graph.predicate_bce = (torch.zeros(
        num_classes,
        dtype=torch.float).scatter_(dim=0, index=unique_predicates,
                                    value=1.0).view(1, -1))
    # Rank vector: the unique predicates, right-padded with -1.
    target_graph.predicate_rank = torch.constant_pad_nd(
        unique_predicates,
        pad=(0, num_classes - len(unique_predicates)),
        value=-1).view(1, -1)

    return input_graph, target_graph

def forward(self, batch):
    data_list = Batch.to_data_list(batch)
    batch_size = len(data_list)
    n_chans = data_list[0].x.shape[-1]

    # Convolve over the stacked (transposed) edge attributes.
    edge_attrs = torch.stack([d.edge_attr.t() for d in data_list])
    edge_attrs_out = self.conv(edge_attrs)
    edge_attrs_out = torch.exp(-edge_attrs_out)

    # Put the new attributes back into the individual graphs.
    for i in range(batch_size):
        data_list[i].edge_attr = edge_attrs_out[i, ...].t()
    return Batch.from_data_list(data_list)

def collate_fn(samples):
    # Filter out None samples.
    samples = [sample for sample in samples if sample is not None]
    if not samples:  # empty
        return None

    if isinstance(samples[0], list):  # list: multiple transforms
        num_transforms = len(samples[0])
        # Flatten to bs * num_transforms entries.
        flatten_list = [_ for sample in samples for _ in sample]
        data_trsfs_dict = OrderedDict()
        for trsf_i in range(num_transforms):
            # Every num_transforms-th entry belongs to this transform;
            # entries are either data objects or (aug, perm) tuples.
            trsf_data = flatten_list[trsf_i::num_transforms]
            trsf_data = [data for data in trsf_data if data is not None]
            if trsf_data:
                if isinstance(trsf_data[0], tuple):  # (aug, perm) pairs
                    sample_list = [_ for sample in trsf_data for _ in sample]
                    data_trsfs_dict[trsf_i] = tuple(
                        Batch.from_data_list(
                            sample_list[pair_i::len(trsf_data[0])])
                        for pair_i in range(len(trsf_data[0])))
                else:  # plain data (dest, mask)
                    data_trsfs_dict[trsf_i] = Batch.from_data_list(trsf_data)
            else:
                # All entries for this transform were None and filtered out.
                data_trsfs_dict[trsf_i] = None
        return list(data_trsfs_dict.values())
    elif isinstance(samples[0], tuple):  # tuple of pairs
        sample_list = [_ for sample in samples for _ in sample]
        left, right = sample_list[::2], sample_list[1::2]
        return Batch.from_data_list(left), Batch.from_data_list(right)
    else:  # already individual graph objects
        return Batch.from_data_list(samples)

def __init__(self,
             X: torch.Tensor,
             edge_index: torch.Tensor,
             num_hops: int,
             n_rollout: int = 10,
             min_atoms: int = 3,
             c_puct: float = 10.0,
             expand_atoms: int = 14,
             high2low: bool = False,
             node_idx: int = None,
             score_func: Callable = None):
    """graph is a networkx graph"""
    self.X = X
    self.edge_index = edge_index
    self.num_hops = num_hops
    self.data = Data(x=self.X, edge_index=self.edge_index)
    self.graph = to_networkx(self.data, to_undirected=True)
    self.data = Batch.from_data_list([self.data])
    self.num_nodes = self.graph.number_of_nodes()
    self.score_func = score_func
    self.n_rollout = n_rollout
    self.min_atoms = min_atoms
    self.c_puct = c_puct
    self.expand_atoms = expand_atoms
    self.high2low = high2low

    # extract the sub-graph and change the node indices.
    if node_idx is not None:
        self.ori_node_idx = node_idx
        self.ori_graph = copy.copy(self.graph)
        x, edge_index, subset, edge_mask, kwargs = self.__subgraph__(
            node_idx, self.X, self.edge_index, self.num_hops)
        self.data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
        self.graph = self.ori_graph.subgraph(subset.tolist())
        mapping = {int(v): k for k, v in enumerate(subset)}
        self.graph = nx.relabel_nodes(self.graph, mapping)
        self.node_idx = torch.where(subset == self.ori_node_idx)[0]
        self.num_nodes = self.graph.number_of_nodes()
        self.subset = subset

    self.root_coalition = sorted(range(self.num_nodes))
    self.MCTSNodeClass = partial(MCTSNode,
                                 data=self.data,
                                 ori_graph=self.graph,
                                 c_puct=self.c_puct)
    self.root = self.MCTSNodeClass(self.root_coalition)
    self.state_map = {str(self.root.coalition): self.root}

def construct_hidden_graph(self, bsize: int, num_agent: int,
                           hidden_size: int):
    # Compute edge connections: all ordered pairs of distinct agents.
    edge_index = torch.tensor(list(permutations(range(num_agent), 2)),
                              dtype=torch.long)
    edge_index = edge_index.t().contiguous()  # Shape: [2, E], E = n * (n - 1)
    e = edge_index.shape[1]

    # Zero-initialized node, global ("u"), and edge features.
    x = torch.zeros((num_agent, hidden_size), dtype=torch.float32,
                    device=self.device)
    u = torch.zeros((1, hidden_size), dtype=torch.float32,
                    device=self.device)
    edge_attr = torch.zeros((e, hidden_size), dtype=torch.float32,
                            device=self.device)

    # Create a list of Data objects, then call Batch.from_data_list().
    data_objs = [
        Data(x=x.clone(), edge_index=edge_index, edge_attr=edge_attr.clone(),
             u=u.clone()) for _ in range(bsize)
    ]
    batch = Batch.from_data_list(data_objs).to(x.device)
    return batch

def collate(output_list):
    if isinstance(output_list[0], torch.Tensor):
        return torch.cat(output_list, dim=0)
    elif geometric and isinstance(output_list[0], Data):
        # `geometric` presumably flags whether torch_geometric is available.
        return Batch.from_data_list(output_list)
    else:
        # Recurse over the transposed structure (list of tuples/lists).
        return [collate(dim) for dim in zip(*output_list)]

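# A minimal sketch of how the recursive collate above behaves on a list of
# per-step outputs (shapes and names here are illustrative only):
import torch
from torch_geometric.data import Data

outputs = [
    (torch.randn(2, 3), Data(x=torch.randn(4, 8))),
    (torch.randn(5, 3), Data(x=torch.randn(1, 8))),
]
merged = collate(outputs)
# merged[0] is a [7, 3] tensor; merged[1] is a Batch holding both graphs.
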
def _obs(self) -> Tuple[Batch, List[List[int]]]:
    """
    Returns
    -------
    Tuple[Batch, List[List[int]]]
        The Batch object contains the PyTorch Geometric graph representing
        the molecule. The list of lists of integers contains all torsions
        of the molecule, each represented by the indices of the four atoms
        that make it up.
    """
    mol = Chem.rdmolops.RemoveHs(self.mol)
    conf = mol.GetConformer()
    atoms = mol.GetAtoms()
    bonds = mol.GetBonds()

    node_features = [
        molecule_features.atom_type_CO(atom) +
        molecule_features.atom_coords(atom, conf) for atom in atoms
    ]
    edge_indices = molecule_features.get_bond_pairs(mol)
    # Each bond appears twice: one edge per direction.
    edge_attributes = [molecule_features.bond_type(bond)
                       for bond in bonds] * 2

    data = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=torch.tensor(edge_indices, dtype=torch.long),
        edge_attr=torch.tensor(edge_attributes, dtype=torch.float),
        pos=torch.Tensor(conf.GetPositions()),
    )
    data = Center()(data)
    data = NormalizeRotation()(data)
    # Overwrite the coordinate part of the features with normalized positions.
    data.x[:, -3:] = data.pos
    data = Batch.from_data_list([data])
    return data, self.nonring

def stack(inp):
    if type(inp[0]) == list:
        ret = []
        for vs in zip(*inp):
            ret.append(stack(vs))
    elif type(inp[0]) == dict:
        ret = {}
        for kvs in zip(*[x.items() for x in inp]):
            ks, vs = zip(*kvs)
            for k in ks:
                assert k == ks[0], "Key value mismatch."
            ret[k] = stack(vs)
    elif type(inp[0]) == torch.Tensor:
        new_t = pad_tensor(inp)
        ret = torch.stack(new_t, 0)
    elif type(inp[0]) == np.ndarray:
        new_t = pad_tensor([torch.from_numpy(x) for x in inp])
        ret = torch.stack(new_t, 0)
    elif type(inp[0]) == str:
        ret = inp
    elif type(inp[0]) == Data:
        # Graphs from torch_geometric: create a batch.
        ret = Batch.from_data_list(inp)
    else:
        raise ValueError("Cannot handle type {}".format(type(inp[0])))
    return ret

def forward(self, data):
    # `k`, `super_node_size`, `num_tours`, and `num_cpus` are presumably
    # module-level settings in the original code.
    subgraph_data = subgraph_loader(data, k, super_node_size, num_tours,
                                    num_cpus)
    subgraphs = [
        get_subgraph(data[subgraph_data.batch[i].item()],
                     subgraph_data.subgraphs[i].squeeze())
        for i in range(len(subgraph_data.subgraphs))
    ]

    # Run the GNN over the subgraphs in chunks of 500 to bound memory use.
    on_gpu = next(self.parameters()).get_device() != -1
    subgraphs_lst = []
    for i in range(0, len(subgraphs), 500):
        subgraphs_b = Batch.from_data_list(subgraphs[i:i + 500])
        if on_gpu:
            subgraphs_b = self.gnn_layer(subgraphs_b.x.cuda(),
                                         subgraphs_b.edge_index.cuda(),
                                         subgraphs_b.batch.cuda())
        else:
            subgraphs_b = self.gnn_layer(subgraphs_b.x,
                                         subgraphs_b.edge_index,
                                         subgraphs_b.batch)
        subgraphs_lst.append(subgraphs_b)
    subgraphs = torch.cat(subgraphs_lst, dim=0)
    subgraphs = self.output_layer(subgraphs)

    weights = subgraph_data.weights.cuda() if on_gpu else subgraph_data.weights
    batch = subgraph_data.batch.cuda() if on_gpu else subgraph_data.batch

    # Weighted average of subgraph energies per input graph.
    subgraphs = subgraphs * weights
    norm = global_add_pool(weights, batch)
    energy = global_add_pool(subgraphs, batch)
    return energy / norm

def _get_schema_graph_encoding(
        self, worlds: List[SpiderWorld],
        initial_graph_embeddings: torch.Tensor) -> torch.Tensor:
    max_num_entities = max(
        len(world.db_context.knowledge_graph.entities) for world in worlds)
    batch_size = initial_graph_embeddings.size(0)

    graph_data_list = []
    for batch_index, world in enumerate(worlds):
        x = initial_graph_embeddings[batch_index]
        adj_list = self._get_graph_adj_lists(
            initial_graph_embeddings.device, world,
            initial_graph_embeddings.size(1) - 1)
        # One edge_index attribute per edge type.
        graph_data = Data(x)
        for i, l in enumerate(adj_list):
            graph_data[f'edge_index_{i}'] = l
        graph_data_list.append(graph_data)

    batch = Batch.from_data_list(graph_data_list)
    gnn_output = self._gnn(
        batch.x,
        [batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)])

    gnn_output = gnn_output.view(batch_size, max_num_entities, -1)
    entities_encodings = gnn_output[:, :max_num_entities]
    return entities_encodings

def _process(self, data_list):
    if len(data_list) == 0:
        return Data()
    # Merge the list into a single Data object by batching and then
    # stripping the batch bookkeeping attributes.
    data = Batch.from_data_list(data_list)
    delattr(data, "batch")
    delattr(data, "ptr")
    return data

def avg_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    defined in :attr:`cluster`.
    Final node features are defined by the *average* features of all nodes
    within the same cluster.
    See :meth:`torch_geometric.nn.pool.max_pool` for more details.

    Args:
        cluster (LongTensor): Cluster vector
            :math:`\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N`, which assigns
            each node to a specific cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform that takes in
            the coarsened and pooled :obj:`torch_geometric.data.Data` object
            and returns a transformed version. (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    cluster, perm = consecutive_cluster(cluster)

    x = None if data.x is None else _avg_pool_x(cluster, data.x)
    index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    batch = None if data.batch is None else pool_batch(perm, data.batch)
    pos = None if data.pos is None else pool_pos(cluster, data.pos)

    data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)

    if transform is not None:
        data = transform(data)

    return data

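# A minimal usage sketch for avg_pool, pairing it with graclus clustering
# (assumes torch-cluster is installed, which graclus depends on; the graph
# below is illustrative):
import torch
from torch_geometric.data import Data
from torch_geometric.nn import graclus, avg_pool

data = Data(x=torch.randn(6, 8),
            edge_index=torch.tensor([[0, 1, 2, 3, 4, 5],
                                     [1, 0, 3, 2, 5, 4]]))
cluster = graclus(data.edge_index, num_nodes=data.num_nodes)
coarse = avg_pool(cluster, data)  # node features averaged per cluster
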
def _save_graphs(sharded, shard_num, out_dir):
    print(f'Processing shard {shard_num}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    for i, (ensemble_name, target_df) in enumerate(
            shard.groupby(['ensemble'])):
        sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10,
                               neg_pos_ratio=1)

        for index, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if (graph1 is None) or (graph2 is None):
                continue
            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(
                out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1

def test_pair_data_batching():
    class PairData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            # Offset each edge_index by the node count of its own graph.
            if key == 'edge_index_s':
                return self.x_s.size(0)
            if key == 'edge_index_t':
                return self.x_t.size(0)
            else:
                return super().__inc__(key, value, *args, **kwargs)

    x_s = torch.randn(5, 16)
    edge_index_s = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])
    x_t = torch.randn(4, 16)
    edge_index_t = torch.tensor([
        [0, 0, 0],
        [1, 2, 3],
    ])

    data = PairData(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t,
                    edge_index_t=edge_index_t)
    batch = Batch.from_data_list([data, data])

    assert torch.allclose(batch.x_s, torch.cat([x_s, x_s], dim=0))
    assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5],
                                           [1, 2, 3, 4, 6, 7, 8, 9]]
    assert torch.allclose(batch.x_t, torch.cat([x_t, x_t], dim=0))
    assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4],
                                           [1, 2, 3, 5, 6, 7]]

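# A related sketch, reusing the PairData instance above: passing follow_batch
# to from_data_list additionally records per-attribute batch vectors (per the
# PyTorch Geometric mini-batching docs).
batch = Batch.from_data_list([data, data], follow_batch=['x_s', 'x_t'])
assert batch.x_s_batch.tolist() == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
assert batch.x_t_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]
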
def buildGraph(feat, label):
    B = feat.shape[0]
    NoOfNodes = feat.shape[1]

    # Fully connected directed graph over the nodes (no self-loops).
    edge_index = list(itertools.permutations(np.arange(0, NoOfNodes), 2))
    edge_index = torch.LongTensor(edge_index).T

    listofData = []
    for i in range(0, B):
        feat_arr = feat[i].detach().cpu().numpy().reshape(NoOfNodes, -1)
        # Euclidean distance between the endpoint features of every edge,
        # computed in the same (a, b) order as edge_index.
        edge_attr = np.asarray([
            np.linalg.norm(feat_arr[a] - feat_arr[b])
            for a, b in itertools.permutations(range(NoOfNodes), 2)
        ])
        edge_attr = torch.Tensor(edge_attr).view(-1)
        data = Data(x=torch.Tensor(feat_arr),
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=label[i].view(-1))
        listofData.append(data)

    batch = Batch.from_data_list(listofData)
    return batch

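# Hypothetical vectorized alternative to the per-edge distance loop above:
# torch.cdist yields the full n-by-n distance matrix, and dropping the
# diagonal matches the ordering of itertools.permutations(range(n), 2).
x = torch.Tensor(feat_arr)                       # [n, d] node features
dist = torch.cdist(x, x)                         # [n, n] pairwise distances
mask = ~torch.eye(x.shape[0], dtype=torch.bool)
edge_attr = dist[mask].view(-1)                  # [n * (n - 1)] attributes
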
def validate(model, validate_loader, batch_size):
    model.eval()
    loss_fn = nn.MSELoss()
    loss_all = 0
    for i in range(len(validate_loader) // batch_size):
        # Conserve GPU memory by freeing the previous iteration's tensors.
        try:
            del pred_, batch, x, edges, y, loss
        except NameError:
            pass

        # Ordered mini-batch.
        batch = [
            validate_loader[j]
            for j in range(i * batch_size, (i + 1) * batch_size)
        ]
        batch = Batch.from_data_list(batch).to(device=CUDA_DEVICE)
        x, edges, y = batch.x, batch.edge_index, batch.y
        with torch.no_grad():
            pred_ = model(batch)
            loss = loss_fn(pred_, y)
        loss_all += loss.item()

    return loss_all / (len(validate_loader) // batch_size)

def explain_node(self, node_idx, x, edge_index, **kwargs):
    data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
    data = data.to(self.device)
    with torch.no_grad():
        _, prob, emb = self.get_model_output(data.x, data.edge_index)
        _, edge_mask = self.forward((data.x, emb, data.edge_index, 1.0),
                                    training=False)
    return edge_mask