Example #1
def test_batching_of_batches():
    data = Data(x=torch.randn(2, 16))
    batch = Batch.from_data_list([data, data])

    batch = Batch.from_data_list([batch, batch])
    assert len(batch) == 2
    assert batch.x[0:2].tolist() == data.x.tolist()
    assert batch.x[2:4].tolist() == data.x.tolist()
    assert batch.x[4:6].tolist() == data.x.tolist()
    assert batch.x[6:8].tolist() == data.x.tolist()
    assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
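The test needs only two imports to run standalone. A minimal sketch, assuming a torch_geometric version with the flattening behavior the assertions describe:

import torch
from torch_geometric.data import Batch, Data

data = Data(x=torch.randn(2, 16))
inner = Batch.from_data_list([data, data])
outer = Batch.from_data_list([inner, inner])
# Nested batching flattens: four copies of the same 2-node graph.
print(outer.batch.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]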
Example #2
 def from_data_list(data_list):
     r"""
     Create a batch from a list of
     torch_points3d.datasets.registration.pair.Pair objects.
     Warning: follow_batch is not supported yet.
     """
     assert isinstance(data_list[0], Pair)
     data_list_s, data_list_t = list(map(list, zip(*[data.to_data() for data in data_list])))
     batch_s = Batch.from_data_list(data_list_s)
     batch_t = Batch.from_data_list(data_list_t)
     return PairBatch.make_pair(batch_s, batch_t).contiguous()
Example #3
    def sample(self, batch_size):
        max_mem = min(self.mem_cntr, self.mem_size)
        # uniform sampling without replacement over the filled memory slots
        batch = np.random.choice(max_mem, batch_size, replace=False)
        graphs_pre_batch = Batch.from_data_list(
            [self.graphs_pre[b] for b in batch])
        graphs_later_batch = Batch.from_data_list(
            [self.graphs_later[b] for b in batch])
        actions_batch = T.tensor([self.actions[b] for b in batch])
        rewards_batch = T.tensor([self.rewards[b] for b in batch])

        return graphs_pre_batch, graphs_later_batch, actions_batch, rewards_batch
Example #4
    def score(self):
        """
        Scoring.
        """
        print("\n\nModel evaluation.\n")
        self.model.eval()

        scores = np.empty(
            (len(self.testing_graphs), len(self.training_graphs)))
        ground_truth = np.empty(
            (len(self.testing_graphs), len(self.training_graphs)))
        prediction_mat = np.empty(
            (len(self.testing_graphs), len(self.training_graphs)))

        rho_list = []
        tau_list = []
        prec_at_10_list = []
        prec_at_20_list = []

        t = tqdm(total=len(self.testing_graphs) * len(self.training_graphs))

        for i, g in enumerate(self.testing_graphs):
            source_batch = Batch.from_data_list([g] *
                                                len(self.training_graphs))
            target_batch = Batch.from_data_list(self.training_graphs)

            data = self.transform((source_batch, target_batch))
            target = data["target"]
            ground_truth[i] = target
            prediction = self.model(data)
            prediction_mat[i] = prediction.detach().numpy()

            scores[i] = (F.mse_loss(prediction, target,
                                    reduction="none").detach().numpy())

            rho_list.append(
                calculate_ranking_correlation(spearmanr, prediction_mat[i],
                                              ground_truth[i]))
            tau_list.append(
                calculate_ranking_correlation(kendalltau, prediction_mat[i],
                                              ground_truth[i]))
            prec_at_10_list.append(
                calculate_prec_at_k(10, prediction_mat[i], ground_truth[i]))
            prec_at_20_list.append(
                calculate_prec_at_k(20, prediction_mat[i], ground_truth[i]))

            t.update(len(self.training_graphs))

        t.close()
        self.rho = np.mean(rho_list).item()
        self.tau = np.mean(tau_list).item()
        self.prec_at_10 = np.mean(prec_at_10_list).item()
        self.prec_at_20 = np.mean(prec_at_20_list).item()
        self.model_error = np.mean(scores).item()
        self.print_evaluation()
Example #5
 def set_input(self, data, device):
     self.input = Batch(pos=data.pos, x=data.x, batch=data.batch).to(device)
     if hasattr(data, "pos_target"):
         self.input_target = Batch(pos=data.pos_target,
                                   x=data.x_target,
                                   batch=data.batch_target).to(device)
     # the match indices are needed in both cases, so set them unconditionally
     self.match = data.pair_ind.to(torch.long).to(device)
     self.size_match = data.size_pair_ind.to(torch.long).to(device)
Example #6
    def keep_human_object_interactions(input_graph: Batch, target_graph: Batch,
                                       filename: str, *, human_class: int):
        """Keep only the relations whose subject node is of the human class."""
        subjs = input_graph.object_classes[input_graph.relation_indexes[0]]
        keep = subjs == human_class

        input_graph.n_edges = keep.sum().item()
        input_graph.relation_indexes = input_graph.relation_indexes[:, keep]
        input_graph.relation_linear_features = input_graph.relation_linear_features[
            keep]

        return input_graph, target_graph
Example #7
 def collate(data_list):
     batch = Batch()
     batch.batch = []
     for key in data_list[0].keys:
         batch[key] = default_collate([d[key] for d in data_list])
     for i, data in enumerate(data_list):
         num_nodes = data.num_nodes
         if num_nodes is not None:
             item = torch.full((num_nodes, ), i, dtype=torch.long)
             batch.batch.append(item)
     batch.batch = torch.cat(batch.batch, dim=0)
     return batch
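Note that default_collate stacks attributes along a new batch dimension, so this collate requires equally-sized graphs; the library's Batch.from_data_list concatenates along the node dimension instead. A minimal sketch of the latter, with hypothetical inputs:

import torch
from torch_geometric.data import Batch, Data

data_list = [Data(x=torch.randn(3, 8)), Data(x=torch.randn(5, 8))]
batch = Batch.from_data_list(data_list)
# One batch index per node: three nodes from graph 0, five from graph 1.
print(batch.batch.tolist())  # [0, 0, 0, 1, 1, 1, 1, 1]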
Example #8
def collate_fn_withpad(data_list):
    '''
    Modified from PyTorch Geometric's implementation: 3-D tensor attributes
    are zero-padded along dim 1 to a common length before concatenation.
    :param data_list: a list of Data objects to collate
    :return: a Batch with padded attributes (original lengths stored in 'tlens')
    '''
    keys = [set(data.keys) for data in data_list]
    keys = list(set.union(*keys))
    assert 'batch' not in keys

    batch = Batch()

    for key in keys:
        batch[key] = []
    batch.batch = []

    cumsum = 0
    for i, data in enumerate(data_list):
        num_nodes = data.num_nodes
        batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
        for key in data.keys:
            item = data[key]
            item = item + cumsum if data.__cumsum__(key, item) else item
            batch[key].append(item)
        cumsum += num_nodes

    for key in keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            if (len(item.shape) == 3):
                tlens = [x.shape[1] for x in batch[key]]
                maxtlens = np.max(tlens)
                to_cat = []
                for x in batch[key]:
                    to_cat.append(
                        torch.cat([
                            x,
                            x.new_zeros(x.shape[0], maxtlens - x.shape[1],
                                        x.shape[2])
                        ],
                                  dim=1))
                batch[key] = torch.cat(to_cat, dim=0)
                if 'tlens' not in batch.keys:
                    batch['tlens'] = item.new_tensor(tlens, dtype=torch.long)
            else:
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, int) or isinstance(item, float):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError('Unsupported attribute type.')
    batch.batch = torch.cat(batch.batch, dim=-1)
    return batch.contiguous()
Example #9
 def forward(self, data, **kwargs):
     # Global pooling: collapse every graph in the batch to a single node.
     batch_obj = Batch()
     x, pos, batch = data.x, data.pos, data.batch
     if pos is not None:
         x = self.nn(torch.cat([x, pos], dim=1))
     x = self.pool(x, batch)
     batch_obj.x = x
     if pos is not None:
         batch_obj.pos = pos.new_zeros((x.size(0), 3))
     batch_obj.batch = torch.arange(x.size(0), device=batch.device)
     copy_from_to(data, batch_obj)
     return batch_obj
Example #10
    def sample(self, batch_size):
        max_mem = min(self.mem_cntr, self.mem_size)
        # sample only from the filled memory slots, weighted by reward
        rewards = np.array(self.rewards[:max_mem], dtype=np.float64)
        p = rewards / rewards.sum()
        batch = np.random.choice(max_mem, batch_size, replace=False, p=p)
        graphs_former_batch = Batch.from_data_list(
            [self.graphs_former[b] for b in batch])
        graphs_later_batch = Batch.from_data_list(
            [self.graphs_later[b] for b in batch])
        actions_batch = torch.Tensor([self.actions[b] for b in batch])
        rewards_batch = torch.Tensor([self.rewards[b] for b in batch])
        done_batch = torch.Tensor([self.done[b] for b in batch])

        return graphs_former_batch, graphs_later_batch, actions_batch, rewards_batch, done_batch
Example #11
    def forward(self, batch):
        r"""Forward computation which computes the raw edge score, normalizes
        it
        """

        data_list = Batch.to_data_list(batch)
        data_list_out = []
        for data in data_list:
            new_edge_attr = softmax(data.edge_attr, data.edge_index[0])
            data.edge_attr = new_edge_attr
            data_list_out.append(data)
        batch = Batch.from_data_list(data_list_out)
        return batch
Example #12
def split_vr_batch(relations: Batch) -> List[Data]:
    # Hack to force torch_geometric to accept our graphs
    relations.x = relations.object_boxes
    relations.__slices__["x"] = relations.__slices__["object_boxes"]
    result = []
    for r in relations.to_data_list():
        r.x = None
        r.n_nodes = r.n_nodes.item()
        r.n_edges = r.n_edges.item()
        result.append(r)
    relations.x = None
    del relations.__slices__["x"]
    return result
Example #13
def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    for i, images in enumerate(train_loader):

        end = time.time()
        for im in images:
            im.edge_attr = None
        images_cls = Batch.from_data_list(images)
        im_q = Batch.from_data_list(random_augmentation(images))
        im_k = Batch.from_data_list(random_augmentation(images))

        data_time.update(time.time() - end)
        if args.gpu is not None:

            im_q = im_q.to(args.gpu)  #, non_blocking=True)
            im_k = im_k.to(args.gpu)  #, non_blocking=True)
            images_cls = images_cls.to(args.gpu)

        output, target, q_cls = model(im_q=im_q, im_k=im_k, image=images_cls)
        if args.gpu is not None:
            target = target.to(args.gpu)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), len(images))
        loss_list.append(loss.item())  # assumes a module-level loss_list
        top1.update(acc1[0], len(images))
        top5.update(acc5[0], len(images))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Example #14
 def make_bce_and_rank_targets(input_graph: Batch, target_graph: Batch,
                               filename: str, *, num_classes):
     """Binary and rank encoding of unique predicates"""
     unique_predicates = torch.unique(target_graph.predicate_classes,
                                      sorted=False)
     target_graph.predicate_bce = (
         torch.zeros(num_classes, dtype=torch.float)
         .scatter_(dim=0, index=unique_predicates, value=1.0)
         .view(1, -1))
     target_graph.predicate_rank = torch.constant_pad_nd(
         unique_predicates,
         pad=(0, num_classes - len(unique_predicates)),
         value=-1).view(1, -1)
     return input_graph, target_graph
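A quick worked example of the two encodings, with hypothetical values (num_classes=5, unique predicates {1, 3}):

import torch

unique = torch.tensor([1, 3])
bce = torch.zeros(5, dtype=torch.float).scatter_(0, unique, 1.0).view(1, -1)
rank = torch.constant_pad_nd(unique, pad=(0, 3), value=-1).view(1, -1)
print(bce.tolist())   # [[0.0, 1.0, 0.0, 1.0, 0.0]]
print(rank.tolist())  # [[1, 3, -1, -1, -1]]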
Example #15
    def forward(self, batch):

        batch = Batch.to_data_list(batch)
        batch_size = len(batch)
        n_chans = batch[0].x.shape[-1]
        edge_attrs = torch.stack([d.edge_attr.t() for d in batch])
        edge_attrs_out = self.conv(edge_attrs)

        edge_attrs_out = torch.exp(-edge_attrs_out)

        # put new attributes in graphs
        for i in range(batch_size):
            batch[i].edge_attr = edge_attrs_out[i, ...].t()

        return Batch.from_data_list(batch)
Example #16
def collate_fn(samples):
    # filtering none
    samples = [sample for sample in samples if sample is not None]
    if samples:  # nonempty : tuple or torch_batch

        if isinstance(samples[0], list):  # list : multiple transform
            num_transforms = len(samples[0])
            flatten_list = [_ for sample in samples
                            for _ in sample]  # bs * num_transforms
            data_trsfs_dict = OrderedDict()
            for trsf_i in range(num_transforms):
                trsf_data = flatten_list[trsf_i::num_transforms]  # list or list of tuples(aug, perm)
                trsf_data = [data for data in trsf_data
                             if data is not None]  # None filtered
                if trsf_data:
                    if isinstance(trsf_data[0], tuple):  # aug, perm

                        sample_list = [
                            _ for sample in trsf_data for _ in sample
                        ]
                        data_trsfs_dict[trsf_i] = tuple([
                            Batch.from_data_list(
                                sample_list[pair_i::len(trsf_data[0])])
                            for pair_i in range(len(trsf_data[0]))
                        ])
#                         left, right = sample_list[::2], sample_list[1::2]
#                         data_trsfs_dict[trsf_i]=(Batch.from_data_list(left), Batch.from_data_list(right))
                    else:  # dest, mask
                        data_trsfs_dict[trsf_i] = Batch.from_data_list(
                            trsf_data)
                else:  # transformed data is all none and filtered out
                    data_trsfs_dict[trsf_i] = None

            return list(data_trsfs_dict.values())

        elif isinstance(samples[0], tuple):  # tuple
            sample_list = [_ for sample in samples for _ in sample]
            left, right = sample_list[::2], sample_list[1::2]
            return Batch.from_data_list(left), Batch.from_data_list(right)
        else:  # torch_batch
            #samples = [sample for sample in samples if sample is not None]
            return Batch.from_data_list(samples)

    else:  #empty
        return None
Example #17
    def __init__(self,
                 X: torch.Tensor,
                 edge_index: torch.Tensor,
                 num_hops: int,
                 n_rollout: int = 10,
                 min_atoms: int = 3,
                 c_puct: float = 10.0,
                 expand_atoms: int = 14,
                 high2low: bool = False,
                 node_idx: int = None,
                 score_func: Callable = None):
        """ graph is a networkX graph """
        self.X = X
        self.edge_index = edge_index
        self.num_hops = num_hops
        self.data = Data(x=self.X, edge_index=self.edge_index)
        self.graph = to_networkx(self.data, to_undirected=True)
        self.data = Batch.from_data_list([self.data])
        self.num_nodes = self.graph.number_of_nodes()
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low

        # extract the sub-graph and change the node indices.
        if node_idx is not None:
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            x, edge_index, subset, edge_mask, kwargs = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list(
                [Data(x=x, edge_index=edge_index)])
            self.graph = self.ori_graph.subgraph(subset.tolist())
            mapping = {int(v): k for k, v in enumerate(subset)}
            self.graph = nx.relabel_nodes(self.graph, mapping)
            self.node_idx = torch.where(subset == self.ori_node_idx)[0]
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = list(range(self.num_nodes))  # already sorted
        self.MCTSNodeClass = partial(MCTSNode,
                                     data=self.data,
                                     ori_graph=self.graph,
                                     c_puct=self.c_puct)
        self.root = self.MCTSNodeClass(self.root_coalition)
        self.state_map = {str(self.root.coalition): self.root}
Example #18
    def construct_hidden_graph(self, bsize: int, num_agent: int,
                               hidden_size: int):
        # Compute edge connections
        edge_index = torch.tensor(list(permutations(range(num_agent), 2)),
                                  dtype=torch.long)
        edge_index = edge_index.t().contiguous()  # Shape: [2 x E], E = n * (n - 1)
        e = edge_index.shape[1]

        # U vector. |U|-dimensional 0-vector
        x = torch.zeros((num_agent, hidden_size),
                        dtype=torch.float32,
                        device=self.device)
        u = torch.zeros((1, hidden_size),
                        dtype=torch.float32,
                        device=self.device)
        edge_attr = torch.zeros((e, hidden_size),
                                dtype=torch.float32,
                                device=self.device)

        # Create list of Data objects, then call Batch.from_data_list()
        data_objs = [
            Data(x=x.clone(),
                 edge_index=edge_index,
                 edge_attr=edge_attr.clone(),
                 u=u.clone()) for _ in range(bsize)
        ]
        batch = Batch.from_data_list(data_objs).to(x.device)
        return batch
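A quick check of the fully connected edge layout this builds: permutations of n agents taken two at a time give E = n(n-1) directed edges, since self-loops are excluded. For example:

from itertools import permutations
import torch

edge_index = torch.tensor(list(permutations(range(3), 2)), dtype=torch.long)
print(edge_index.t().contiguous().shape)  # torch.Size([2, 6]) for n = 3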
Example #19
 def collate(output_list):
     if isinstance(output_list[0], torch.Tensor):
         return torch.cat(output_list, dim=0)
     elif geometric and isinstance(output_list[0], Data):
         return Batch.from_data_list(output_list)
     else:
         return [collate(dim) for dim in zip(*output_list)]
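A hedged usage sketch with hypothetical inputs: when the outputs are (tensor, Data) pairs, the recursive branch collates element-wise. This assumes the module-level geometric flag is truthy.

import torch
from torch_geometric.data import Data

outs = [(torch.randn(1, 4), Data(x=torch.randn(2, 3))),
        (torch.randn(1, 4), Data(x=torch.randn(3, 3)))]
t, b = collate(outs)  # t: a [2, 4] tensor; b: a Batch of two graphs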
Example #20
    def _obs(self) -> Tuple[Batch, List[List[int]]]:
        """
        returns
        -------
        Tuple[Batch, List[List[int]]
            The Batch object contains the Pytorch Geometric graph representing the molecule. The list of lists of integers
            is a list of all the torsions of the molecule, where each torsion is represented by a list of four integers, where the integers
            are the indices of the four atoms making up the torsion.
        """
        mol = Chem.rdmolops.RemoveHs(self.mol)
        conf = mol.GetConformer()
        atoms = mol.GetAtoms()
        bonds = mol.GetBonds()

        node_features = [molecule_features.atom_type_CO(atom) + molecule_features.atom_coords(atom, conf) for atom in atoms]
        edge_indices = molecule_features.get_bond_pairs(mol)
        edge_attributes = [molecule_features.bond_type(bond) for bond in bonds] * 2  # one copy per edge direction


        data = Data(
                    x=torch.tensor(node_features, dtype=torch.float),
                    edge_index=torch.tensor(edge_indices, dtype=torch.long),
                    edge_attr=torch.tensor(edge_attributes, dtype=torch.float),
                    pos=torch.Tensor(conf.GetPositions())
                )

        data = Center()(data)
        data = NormalizeRotation()(data)
        data.x[:, -3:] = data.pos  # overwrite coordinate features with the normalized positions
        data = Batch.from_data_list([data])
        return data, self.nonring
Example #21
 def stack(inp):
     if type(inp[0]) == list:
         ret = []
         for vs in zip(*inp):
             ret.append(stack(vs))
     elif type(inp[0]) == dict:
         ret = {}
         for kvs in zip(*[x.items() for x in inp]):
             ks, vs = zip(*kvs)
             for k in ks:
                 assert k == ks[0], "Key value mismatch."
             ret[k] = stack(vs)
     elif type(inp[0]) == torch.Tensor:
         new_t = pad_tensor(inp)
         ret = torch.stack(new_t, 0)
     elif type(inp[0]) == np.ndarray:
         new_t = pad_tensor([torch.from_numpy(x) for x in inp])
         ret = torch.stack(new_t, 0)
     elif type(inp[0]) == str:
         ret = inp
     elif type(inp[0]) == Data:  # Graph from torch.geometric, create a batch
         ret = Batch.from_data_list(inp)
     else:
         raise ValueError("Cannot handle type {}".format(type(inp[0])))
     return ret
Example #22
 def forward(self, data):
     subgraph_data = subgraph_loader(data, k, super_node_size, num_tours,
                                     num_cpus)
     subgraphs = [
         get_subgraph(data[subgraph_data.batch[i].item()],
                      subgraph_data.subgraphs[i].squeeze())
         for i in range(len(subgraph_data.subgraphs))
     ]
     subgraphs_lst = []
     for i in range(0, len(subgraphs), 500):
         subgraphs_b = Batch.from_data_list(subgraphs[i:i + 500])
         if next(self.parameters()).get_device() != -1:
             subgraphs_b = self.gnn_layer(subgraphs_b.x.cuda(),
                                          subgraphs_b.edge_index.cuda(),
                                          subgraphs_b.batch.cuda())
         else:
             subgraphs_b = self.gnn_layer(subgraphs_b.x,
                                          subgraphs_b.edge_index,
                                          subgraphs_b.batch)
         subgraphs_lst.append(subgraphs_b)
     subgraphs = torch.cat(subgraphs_lst, dim=0)
     subgraphs = self.output_layer(subgraphs)
     on_gpu = next(self.parameters()).get_device() != -1
     weights = subgraph_data.weights.cuda() if on_gpu else subgraph_data.weights
     batch = subgraph_data.batch.cuda() if on_gpu else subgraph_data.batch
     subgraphs = subgraphs * weights
     norm = global_add_pool(weights, batch)
     energy = global_add_pool(subgraphs, batch)
     return energy / norm
Example #23
    def _get_schema_graph_encoding(
        self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor
    ) -> torch.Tensor:
        max_num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        batch_size = initial_graph_embeddings.size(0)

        graph_data_list = []

        for batch_index, world in enumerate(worlds):
            x = initial_graph_embeddings[batch_index]

            adj_list = self._get_graph_adj_lists(
                initial_graph_embeddings.device, world,
                initial_graph_embeddings.size(1) - 1)
            graph_data = Data(x)
            for i, l in enumerate(adj_list):
                graph_data[f'edge_index_{i}'] = l
            graph_data_list.append(graph_data)

        batch = Batch.from_data_list(graph_data_list)

        gnn_output = self._gnn(batch.x, [
            batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)
        ])

        num_nodes = max_num_entities
        gnn_output = gnn_output.view(batch_size, num_nodes, -1)
        # entities_encodings = gnn_output
        entities_encodings = gnn_output[:, :max_num_entities]
        # global_node_encodings = gnn_output[:, max_num_entities]

        return entities_encodings
Example #24
 def _process(self, data_list):
     if len(data_list) == 0:
         return Data()
     # Collate into one big graph, then strip the batch bookkeeping
     # so the result behaves like a plain Data object.
     data = Batch.from_data_list(data_list)
     delattr(data, "batch")
     delattr(data, "ptr")
     return data
Example #25
def avg_pool(cluster, data, transform=None):
    r"""Pools and coarsens a graph given by the
    :class:`torch_geometric.data.Data` object according to the clustering
    defined in :attr:`cluster`.
    Final node features are defined by the *average* features of all nodes
    within the same cluster.
    See :meth:`torch_geometric.nn.pool.max_pool` for more details.

    Args:
        cluster (LongTensor): Cluster vector :math:`\mathbf{c} \in \{ 0,
            \ldots, N - 1 \}^N`, which assigns each node to a specific cluster.
        data (Data): Graph data object.
        transform (callable, optional): A function/transform that takes in the
            coarsened and pooled :obj:`torch_geometric.data.Data` object and
            returns a transformed version. (default: :obj:`None`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    cluster, perm = consecutive_cluster(cluster)

    x = None if data.x is None else _avg_pool_x(cluster, data.x)
    index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)
    batch = None if data.batch is None else pool_batch(perm, data.batch)
    pos = None if data.pos is None else pool_pos(cluster, data.pos)

    data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)

    if transform is not None:
        data = transform(data)

    return data
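A usage sketch with a hand-made cluster assignment (a hypothetical 4-node graph):

import torch
from torch_geometric.data import Data

data = Data(x=torch.randn(4, 16),
            edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))
cluster = torch.tensor([0, 0, 1, 1])  # merge nodes {0, 1} and {2, 3}
pooled = avg_pool(cluster, data)      # two nodes with averaged features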
Example #26
def _save_graphs(sharded, shard_num, out_dir):
    print(f'Processing shard {shard_num}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    for i, (ensemble_name, target_df) in enumerate(shard.groupby(['ensemble'])):

        sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)
        
        for index, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if (graph1 is None) or (graph2 is None):
                continue

            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1
Example #27
def test_pair_data_batching():
    class PairData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            if key == 'edge_index_s':
                return self.x_s.size(0)
            if key == 'edge_index_t':
                return self.x_t.size(0)
            else:
                return super().__inc__(key, value, *args, **kwargs)

    x_s = torch.randn(5, 16)
    edge_index_s = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])
    x_t = torch.randn(4, 16)
    edge_index_t = torch.tensor([
        [0, 0, 0],
        [1, 2, 3],
    ])

    data = PairData(x_s=x_s,
                    edge_index_s=edge_index_s,
                    x_t=x_t,
                    edge_index_t=edge_index_t)
    batch = Batch.from_data_list([data, data])

    assert torch.allclose(batch.x_s, torch.cat([x_s, x_s], dim=0))
    assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5],
                                           [1, 2, 3, 4, 6, 7, 8, 9]]

    assert torch.allclose(batch.x_t, torch.cat([x_t, x_t], dim=0))
    assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4],
                                           [1, 2, 3, 5, 6, 7]]
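The same pair can also be collated with follow_batch to obtain per-source and per-target node assignment vectors; a sketch mirroring torch_geometric's documentation on pairs of graphs (behavior may vary across library versions):

batch = Batch.from_data_list([data, data], follow_batch=['x_s', 'x_t'])
print(batch.x_s_batch.tolist())  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
print(batch.x_t_batch.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]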
Example #28
def buildGraph(feat, label):
    B = feat.shape[0]
    NoOfNodes = feat.shape[1]
    #feat.reshape(B,NoOfNodes,-1)
    edge_index = list(itertools.permutations(np.arange(0, NoOfNodes), 2))
    edge_index = torch.LongTensor(edge_index).T
    listofData = []
    for i in range(0, B):
        feat_arr = feat[i].detach().cpu().numpy().reshape(NoOfNodes, -1)
        edge_attr = np.asarray([
            np.linalg.norm(a - b)
            for a, b in itertools.product(feat_arr, feat_arr)
        ])
        # for a in feat_arr[i]:
        #     for b in feat_arr[i]:
        #         print(np.linalg.norm(a-b))
        edge_attr = edge_attr[edge_attr > 0]
        edge_attr = torch.Tensor(edge_attr).view(-1)
        data = Data(x=torch.Tensor(feat_arr),
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=label[i].view(-1))
        listofData.append(data)
    batch = Batch.from_data_list(listofData)

    return batch
Example #29
def validate(model, validate_loader, batch_size):
    model.eval()
    loss_fn = nn.MSELoss()
    loss_all = 0

    for i in range(len(validate_loader) // batch_size):

        # conserve gpu memory
        try:
            del pred_, batch, x, edges, y, loss
        except NameError:  # nothing to free on the first iteration
            pass

        # ordered mini-batch
        batch = [
            validate_loader[j]
            for j in range(i * batch_size, (i + 1) * batch_size)
        ]
        batch = Batch.from_data_list(batch).to(device=CUDA_DEVICE)

        x, edges, y = batch.x, batch.edge_index, batch.y
        with torch.no_grad():  # no gradients needed during validation
            pred_ = model(batch)

        loss = loss_fn(pred_, y)
        loss_all += loss.item()

    return loss_all / (len(validate_loader) // batch_size)
Example #30
 def explain_node(self, node_idx, x, edge_index, **kwargs):
     data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
     data = data.to(self.device)
     with torch.no_grad():
         _, prob, emb = self.get_model_output(data.x, data.edge_index)
         _, edge_mask = self.forward((data.x, emb, data.edge_index, 1.0), training=False)
     return edge_mask