Exemplo n.º 1
0
def test_batching_of_batches():
    """Batching ``Batch`` objects flattens them into a single batch.

    Two copies of one graph form an inner batch; two copies of that inner
    batch are batched again. The outer batch must expose four graphs, in
    order, each contributing two node rows, with a per-node graph index.
    """
    single = Data(x=torch.randn(2, 16))
    inner = Batch.from_data_list([single, single])

    outer = Batch.from_data_list([inner, inner])
    assert len(outer) == 2

    expected_rows = single.x.tolist()
    # Four flattened graphs, two consecutive node rows each.
    for start in range(0, 8, 2):
        assert outer.x[start:start + 2].tolist() == expected_rows
    assert outer.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
Exemplo n.º 2
0
    def score(self):
        """
        Scoring.

        Evaluate ``self.model`` on every (testing graph, training graph) pair
        and store the averaged metrics on ``self``:

        * ``rho`` / ``tau``: mean Spearman / Kendall ranking correlation per
          testing graph (via ``calculate_ranking_correlation``);
        * ``prec_at_10`` / ``prec_at_20``: mean ``calculate_prec_at_k``;
        * ``model_error``: mean element-wise MSE over all pairs.

        The results are then printed with ``self.print_evaluation()``.
        """
        import torch  # local import: only ``torch.no_grad`` is needed here

        print("\n\nModel evaluation.\n")
        self.model.eval()

        num_test = len(self.testing_graphs)
        num_train = len(self.training_graphs)
        scores = np.empty((num_test, num_train))
        ground_truth = np.empty((num_test, num_train))
        prediction_mat = np.empty((num_test, num_train))

        rho_list = []
        tau_list = []
        prec_at_10_list = []
        prec_at_20_list = []

        # Evaluation needs no autograd graph, so ``no_grad`` avoids building
        # and retaining one per forward pass. Using the progress bar as a
        # context manager guarantees it is closed afterwards.
        with torch.no_grad(), tqdm(total=num_test * num_train) as t:
            for i, g in enumerate(self.testing_graphs):
                # Pair testing graph ``g`` with every training graph. The
                # target batch is rebuilt each iteration, presumably because
                # ``self.transform`` may modify its inputs (TODO confirm
                # before hoisting it out of the loop).
                source_batch = Batch.from_data_list([g] * num_train)
                target_batch = Batch.from_data_list(self.training_graphs)

                data = self.transform((source_batch, target_batch))
                target = data["target"]
                ground_truth[i] = target
                prediction = self.model(data)
                prediction_mat[i] = prediction.detach().numpy()

                scores[i] = (F.mse_loss(prediction, target,
                                        reduction="none").detach().numpy())

                rho_list.append(
                    calculate_ranking_correlation(spearmanr, prediction_mat[i],
                                                  ground_truth[i]))
                tau_list.append(
                    calculate_ranking_correlation(kendalltau, prediction_mat[i],
                                                  ground_truth[i]))
                prec_at_10_list.append(
                    calculate_prec_at_k(10, prediction_mat[i], ground_truth[i]))
                prec_at_20_list.append(
                    calculate_prec_at_k(20, prediction_mat[i], ground_truth[i]))

                t.update(num_train)

        self.rho = np.mean(rho_list).item()
        self.tau = np.mean(tau_list).item()
        self.prec_at_10 = np.mean(prec_at_10_list).item()
        self.prec_at_20 = np.mean(prec_at_20_list).item()
        self.model_error = np.mean(scores).item()
        self.print_evaluation()
Exemplo n.º 3
0
 def from_data_list(data_list):
     r"""
     Build a batch from a list of
     torch_points3d.datasets.registration.pair.Pair objects.

     Each pair is split with ``to_data()`` into its source and target parts.
     The two sides are batched separately with ``Batch.from_data_list`` and
     recombined with ``PairBatch.make_pair(...).contiguous()``.
     Warning : follow_batch is not here yet...
     """
     assert isinstance(data_list[0], Pair)
     split_pairs = [pair.to_data() for pair in data_list]
     sources, targets = (list(side) for side in zip(*split_pairs))
     source_batch = Batch.from_data_list(sources)
     target_batch = Batch.from_data_list(targets)
     return PairBatch.make_pair(source_batch, target_batch).contiguous()
Exemplo n.º 4
0
    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Only the filled part of the buffer (the first
        ``min(self.mem_cntr, self.mem_size)`` slots) is eligible.

        Returns ``(graphs_pre_batch, graphs_later_batch, actions_batch,
        rewards_batch)``. The two graph lists are collated with
        ``Batch.from_data_list``; actions and rewards are built with
        ``T.tensor``.
        """
        filled = min(self.mem_cntr, self.mem_size)
        picks = np.random.choice(filled, batch_size, replace=False)

        def gather(store):
            # Select the sampled slots from one parallel storage list.
            return [store[idx] for idx in picks]

        return (
            Batch.from_data_list(gather(self.graphs_pre)),
            Batch.from_data_list(gather(self.graphs_later)),
            T.tensor(gather(self.actions)),
            T.tensor(gather(self.rewards)),
        )
Exemplo n.º 5
0
    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions, weighted by stored reward.

        Only the filled part of the buffer (the first
        ``min(self.mem_cntr, self.mem_size)`` slots) is eligible. Each eligible
        slot is drawn with probability proportional to its reward.

        Returns ``(graphs_former_batch, graphs_later_batch, actions_batch,
        rewards_batch, done_batch)``. The graph lists are collated with
        ``Batch.from_data_list``; the rest are ``torch.Tensor``.

        Raises:
            ValueError: from ``np.random.choice`` when the filled rewards are
                not a valid (non-negative, positive-sum) weighting, or when
                ``batch_size`` exceeds the number of filled slots.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        # Restrict weights to the filled slots so the sampled index range and
        # ``p`` always have the same length, and unfilled slots can never be
        # drawn.
        weights = np.asarray(self.rewards)[:max_mem]
        p = weights / np.sum(weights)
        batch = np.random.choice(max_mem, batch_size, replace=False, p=p)
        graphs_former_batch = Batch.from_data_list(
            [self.graphs_former[b] for b in batch])
        graphs_later_batch = Batch.from_data_list(
            [self.graphs_later[b] for b in batch])
        actions_batch = torch.Tensor([self.actions[b] for b in batch])
        rewards_batch = torch.Tensor([self.rewards[b] for b in batch])
        done_batch = torch.Tensor([self.done[b] for b in batch])

        return graphs_former_batch, graphs_later_batch, actions_batch, rewards_batch, done_batch
Exemplo n.º 6
0
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch over ``train_loader``.

    For each list of graphs yielded by the loader:

    * ``edge_attr`` is cleared on every graph (in place);
    * the graphs are batched as-is (``images_cls``) and as two views produced
      by ``random_augmentation`` (``im_q``, ``im_k``);
    * the model returns ``(output, target, q_cls)`` and
      ``criterion(output, target)`` is minimised with ``optimizer``.

    Iteration time, data time, loss and top-1/top-5 accuracy are tracked.
    Progress is printed every ``args.print_freq`` iterations. When
    ``args.gpu`` is not ``None``, tensors are moved to that device.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # Start the clock before the first fetch. ``end`` is reset only at the
    # bottom of each iteration, so ``data_time`` covers fetching the next item
    # from ``train_loader`` plus building the batches, and ``batch_time``
    # covers the whole iteration.
    end = time.time()
    for i, images in enumerate(train_loader):
        for im in images:
            im.edge_attr = None
        images_cls = Batch.from_data_list(images)
        im_q = Batch.from_data_list(random_augmentation(images))
        im_k = Batch.from_data_list(random_augmentation(images))

        data_time.update(time.time() - end)
        if args.gpu is not None:
            im_q = im_q.to(args.gpu)
            im_k = im_k.to(args.gpu)
            images_cls = images_cls.to(args.gpu)

        output, target, q_cls = model(im_q=im_q, im_k=im_k, image=images_cls)
        if args.gpu is not None:
            target = target.to(args.gpu)
        loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), len(images))
        # NOTE(review): ``loss_list`` is not defined in this function; it is
        # presumably a module-level list collecting per-step losses — confirm.
        loss_list.append(loss.item())
        top1.update(acc1[0], len(images))
        top5.update(acc5[0], len(images))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
Exemplo n.º 7
0
    def __init__(self,
                 X: torch.Tensor,
                 edge_index: torch.Tensor,
                 num_hops: int,
                 n_rollout: int = 10,
                 min_atoms: int = 3,
                 c_puct: float = 10.0,
                 expand_atoms: int = 14,
                 high2low: bool = False,
                 node_idx: int = None,
                 score_func: Callable = None):
        """Prepare an MCTS search over coalitions of graph nodes.

        ``X`` and ``edge_index`` are wrapped in a ``Data`` object. Its
        undirected networkx view becomes ``self.graph``, and a one-graph
        ``Batch`` becomes ``self.data``. When ``node_idx`` is given, both are
        replaced by the ``num_hops`` sub-graph returned by ``__subgraph__``
        with nodes relabelled ``0..k-1``. The original graph is kept in
        ``self.ori_graph`` and the retained node ids in ``self.subset``.
        The root node's coalition contains every node of ``self.graph``.
        """
        self.X = X
        self.edge_index = edge_index
        self.num_hops = num_hops
        full_graph = Data(x=self.X, edge_index=self.edge_index)
        self.data = full_graph
        self.graph = to_networkx(full_graph, to_undirected=True)
        self.data = Batch.from_data_list([full_graph])
        self.num_nodes = self.graph.number_of_nodes()

        # Search hyper-parameters.
        self.score_func = score_func
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.high2low = high2low

        # Node-level explanation: restrict everything to the target's k-hop
        # neighbourhood and renumber its nodes.
        if node_idx is not None:
            self.ori_node_idx = node_idx
            self.ori_graph = copy.copy(self.graph)
            sub_x, sub_edge_index, subset, _edge_mask, _extra = \
                self.__subgraph__(node_idx, self.X, self.edge_index, self.num_hops)
            self.data = Batch.from_data_list(
                [Data(x=sub_x, edge_index=sub_edge_index)])
            relabel = {int(old): new for new, old in enumerate(subset)}
            self.graph = nx.relabel_nodes(
                self.ori_graph.subgraph(subset.tolist()), relabel)
            self.node_idx = torch.where(subset == self.ori_node_idx)[0]
            self.num_nodes = self.graph.number_of_nodes()
            self.subset = subset

        self.root_coalition = list(range(self.num_nodes))
        self.MCTSNodeClass = partial(MCTSNode,
                                     data=self.data,
                                     ori_graph=self.graph,
                                     c_puct=self.c_puct)
        self.root = self.MCTSNodeClass(self.root_coalition)
        self.state_map = {str(self.root.coalition): self.root}
Exemplo n.º 8
0
def collate_fn(samples):
    """Collate dataset samples into PyG batches, dropping ``None`` entries.

    Supported sample shapes, decided by the first remaining sample:

    * list (one entry per transform): returns a list with one element per
      transform position. Each element is a ``Batch``, or a tuple of
      ``Batch`` objects when that transform yields tuples, or ``None`` when
      every sample's entry at that position was ``None``.
    * tuple: the flattened elements alternate between two sides, and a
      ``(left_batch, right_batch)`` pair is returned.
    * anything else: a single ``Batch``.

    Returns ``None`` when no non-``None`` sample remains.
    """
    kept = [item for item in samples if item is not None]
    if not kept:
        return None

    head = kept[0]
    if isinstance(head, list):
        n_trsf = len(head)
        flat = [entry for item in kept for entry in item]  # bs * n_trsf
        collated = []
        for trsf_i in range(n_trsf):
            # Every n_trsf-th flattened entry belongs to transform ``trsf_i``.
            column = [entry for entry in flat[trsf_i::n_trsf]
                      if entry is not None]
            if not column:
                # The transform produced nothing usable for this batch.
                collated.append(None)
            elif isinstance(column[0], tuple):
                width = len(column[0])
                parts = [part for entry in column for part in entry]
                collated.append(tuple(
                    Batch.from_data_list(parts[pos::width])
                    for pos in range(width)))
            else:
                collated.append(Batch.from_data_list(column))
        return collated

    if isinstance(head, tuple):
        parts = [part for item in kept for part in item]
        return Batch.from_data_list(parts[::2]), Batch.from_data_list(parts[1::2])

    return Batch.from_data_list(kept)
Exemplo n.º 9
0
def test_pair_data_batching():
    """A ``Data`` subclass can hold two graphs via a custom ``__inc__``.

    When batching, ``edge_index_s`` is offset by the number of source nodes
    (``x_s``) and ``edge_index_t`` by the number of target nodes (``x_t``).
    """
    class PairData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            node_attr = {'edge_index_s': 'x_s',
                         'edge_index_t': 'x_t'}.get(key)
            if node_attr is None:
                return super().__inc__(key, value, *args, **kwargs)
            return getattr(self, node_attr).size(0)

    src_x = torch.randn(5, 16)
    src_edges = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]])
    tgt_x = torch.randn(4, 16)
    tgt_edges = torch.tensor([[0, 0, 0], [1, 2, 3]])

    pair = PairData(x_s=src_x, edge_index_s=src_edges,
                    x_t=tgt_x, edge_index_t=tgt_edges)
    batch = Batch.from_data_list([pair, pair])

    # Source side: the second copy's edges are shifted by 5 source nodes.
    assert torch.allclose(batch.x_s, torch.cat([src_x, src_x], dim=0))
    assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5],
                                           [1, 2, 3, 4, 6, 7, 8, 9]]

    # Target side: the second copy's edges are shifted by 4 target nodes.
    assert torch.allclose(batch.x_t, torch.cat([tgt_x, tgt_x], dim=0))
    assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4],
                                           [1, 2, 3, 5, 6, 7]]
Exemplo n.º 10
0
def validate(model, validate_loader, batch_size):
    """Return the mean MSE loss of ``model`` over ordered mini-batches.

    ``validate_loader`` is indexed directly (``validate_loader[j]``) and split
    into ``len(validate_loader) // batch_size`` consecutive mini-batches.
    Trailing samples that do not fill a whole mini-batch are skipped. Each
    mini-batch is collated with ``Batch.from_data_list``, moved to
    ``CUDA_DEVICE``, and compared against its ``y`` attribute.

    Raises:
        ZeroDivisionError: if fewer than ``batch_size`` samples are available.
    """
    import torch  # local import for ``torch.no_grad``

    model.eval()
    loss_fn = nn.MSELoss()
    loss_all = 0
    num_batches = len(validate_loader) // batch_size

    # No autograd graph is needed for validation. Tensors from each iteration
    # are released explicitly at its end to keep GPU memory low; this replaces
    # the earlier ``del`` guarded by a bare ``except``.
    with torch.no_grad():
        for i in range(num_batches):
            # ordered mini-batch
            start = i * batch_size
            batch = Batch.from_data_list(
                [validate_loader[j] for j in range(start, start + batch_size)]
            ).to(device=CUDA_DEVICE)

            pred_ = model(batch)
            loss = loss_fn(pred_, batch.y)
            loss_all += loss.item()

            del pred_, batch, loss

    return loss_all / num_batches
Exemplo n.º 11
0
 def sample_buffer(self, batch_size):
     """Return a ``Batch`` of ``batch_size`` distinct stored graphs.

     Indices are drawn uniformly without replacement from the filled part of
     the buffer (the first ``min(self.mem_cntr, self.mem_size)`` slots of
     ``self.graph_memory``) and collated with ``Batch.from_data_list``.
     """
     max_mem = min(self.mem_cntr, self.mem_size)
     batch = np.random.choice(max_mem, batch_size, replace=False)
     graph_list = [self.graph_memory[b] for b in batch]
     return Batch.from_data_list(graph_list)
Exemplo n.º 12
0
 def explain_node(self, node_idx, x, edge_index, **kwargs):
     """Return the explainer's edge mask for the graph ``(x, edge_index)``.

     The graph is wrapped in a one-graph ``Batch`` on ``self.device``. Node
     embeddings come from ``self.get_model_output`` (third return value).
     The mask is the second value returned by ``self.forward``, called with
     ``1.0`` as its last input and ``training=False``. Both calls run without
     gradient tracking. ``node_idx`` and ``kwargs`` are not used here.
     """
     graph = Batch.from_data_list([Data(x=x, edge_index=edge_index)]).to(self.device)
     with torch.no_grad():
         _logits, _prob, embeddings = self.get_model_output(graph.x, graph.edge_index)
         _pred, edge_mask = self.forward(
             (graph.x, embeddings, graph.edge_index, 1.0), training=False)
     return edge_mask
Exemplo n.º 13
0
def test_parallel(model, loader, total, batch_size, loss_ftn_obj):
    """Evaluate ``model`` over ``loader`` and return the mean per-batch loss.

    ``loss_ftn_obj.name`` selects how each batch loss is computed:

    * ``'vae_loss'``: the model returns ``(output, mu, log_var)`` and the
      target is the concatenated ``x`` of the graphs in ``data``.
    * ``'emd_loss'`` / ``'chamfer_loss'``: ``data`` is collated into a
      ``Batch``, and the loss receives its ``x`` and ``batch`` vectors; the
      result is averaged.
    * ``'emd_loss_layer'``: the model returns the loss itself (averaged).
    * anything else: the loss between the output and the concatenated ``x``.

    ``total / batch_size`` only sets the progress-bar length. Runs under
    ``torch.no_grad()``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    sum_loss = 0.
    num_batches = 0
    t = tqdm.tqdm(enumerate(loader), total=total / batch_size)
    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for _, data in t:
            # forward and loss
            if loss_ftn_obj.name == 'vae_loss':
                batch_output, mu, log_var = model(data)
                y = torch.cat([d.x for d in data]).to(device)
                batch_loss_item = loss_ftn_obj.loss_ftn(batch_output, y, mu,
                                                        log_var).item()
            elif loss_ftn_obj.name in ('emd_loss', 'chamfer_loss'):
                batch_output = model(data)
                data_batch = Batch.from_data_list(data).to(device)
                batch_loss = loss_ftn_obj.loss_ftn(batch_output, data_batch.x,
                                                   data_batch.batch)
                batch_loss_item = batch_loss.mean().item()
            elif loss_ftn_obj.name == 'emd_loss_layer':
                _, batch_loss = model(data)
                batch_loss_item = batch_loss.mean().item()
            else:
                batch_output = model(data)
                y = torch.cat([d.x for d in data]).to(device)
                batch_loss_item = loss_ftn_obj.loss_ftn(batch_output, y).item()

            sum_loss += batch_loss_item
            num_batches += 1
            t.set_description('eval loss = %.5f' % (batch_loss_item))
            t.refresh()  # to show immediately the update

    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return sum_loss / num_batches
Exemplo n.º 14
0
def _one_test_case(layer_generator):
    """Smoke-test one randomly sized layer built by ``layer_generator``.

    On a batch of 10 graphs from ``_generate_data`` it checks that:

    1. the output has ``output_x_dim`` (set equal to ``x_dim``) columns;
    2. for layers flagged homogeneous (except GATConv, GINConv, EpsGINConv
       and MPNNConv modules), scaling ``x``, ``input_x`` and ``edge_attr`` by
       a random factor ``s`` scales the output by ``s`` (within 1e-3);
    3. 1000 Adam steps on ``sum(output_x ** 2)`` give a loss history that is
       negatively correlated with the step index, unless the history is
       essentially constant (std <= 1e-4).
    """
    x_dim, input_x_dim, edge_attr_dim, _ignored = \
        np.random.randint(1, 100, size=4).tolist()
    # The fourth sampled size is discarded; the output width is tied to x_dim.
    output_x_dim = x_dim

    graphs = [_generate_data(x_dim, input_x_dim, edge_attr_dim, output_x_dim)
              for _ in range(10)]
    data = Batch.from_data_list(graphs)
    layer = layer_generator(x_dim=x_dim, input_x_dim=input_x_dim,
                            output_x_dim=output_x_dim,
                            edge_attr_dim=edge_attr_dim).to(data.x.device)
    output_x = layer(data)

    # Output dimensionality.
    assert output_x.size(1) == output_x_dim

    # Homogeneity: f(s * inputs) == s * f(inputs).
    non_homogeneous = ['GATConv', 'GINConv', 'EpsGINConv', 'MPNNConv']
    if (layer.homogeneous_flag
            and layer.module.gnn_module.__class__.__name__ not in non_homogeneous):
        s = np.random.rand() * 1000.
        data.x, data.input_x, data.edge_attr = data.x * s, data.input_x * s, data.edge_attr * s
        assert torch.max(torch.abs(output_x * s - layer(data))) < 1e-3
        data.x, data.input_x, data.edge_attr = data.x / s, data.input_x / s, data.edge_attr / s

    # Backward pass: the loss should trend downwards while training.
    loss_hist = []
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3, eps=1e-5)
    for _ in range(1000):
        optimizer.zero_grad()
        loss = torch.sum(layer(data) ** 2)
        loss_hist.append(loss.item())
        loss.backward()
        optimizer.step()
    print('**', loss_hist[::100])
    if np.std(loss_hist) > 1e-4:
        corr = np.corrcoef(np.arange(len(loss_hist)), loss_hist)[0, 1]
        assert corr < -1e-3
Exemplo n.º 15
0
 def _process(self, data_list):
     """Merge ``data_list`` into one object without batch bookkeeping.

     An empty list yields an empty ``Data()``. Otherwise the graphs are
     concatenated with ``Batch.from_data_list``, and its ``batch`` and
     ``ptr`` attributes are removed before returning.
     """
     if len(data_list) == 0:
         return Data()
     merged = Batch.from_data_list(data_list)
     for attr in ("batch", "ptr"):
         delattr(merged, attr)
     return merged
Exemplo n.º 16
0
    def __init__(self, node_idx: int, X: torch.Tensor, edge_index: torch.Tensor,
                 ori_graph: nx.Graph, n_rollout: int, min_atoms: int, c_puct: float,
                 expand_atoms: int, score_func=None, num_hops: int = 3):
        """Prepare an MCTS search explaining node ``node_idx``.

        The sub-graph around ``node_idx`` is extracted with ``__subgraph__``
        (called without ``num_hops``; presumably it reads ``self.num_hops`` —
        confirm). It is stored as a one-graph ``Batch`` in ``self.data`` and
        as a relabelled (``0..k-1``) networkx graph in ``self.graph``.
        ``self.node_idx`` is the target's index inside that sub-graph. The
        root coalition contains every sub-graph node.
        """
        self.X = X
        self.edge_index = edge_index
        self.num_hops = num_hops
        self.ori_graph = ori_graph
        self.ori_node_idx = node_idx
        self.ori_num_nodes = ori_graph.number_of_nodes()

        # Search hyper-parameters.
        self.n_rollout = n_rollout
        self.min_atoms = min_atoms
        self.c_puct = c_puct
        self.expand_atoms = expand_atoms
        self.score_func = score_func

        # Restrict to the target's neighbourhood and renumber its nodes.
        sub_x, sub_edge_index, subset, _edge_mask, _extra = \
            self.__subgraph__(node_idx, self.X, self.edge_index)
        self.data = Batch.from_data_list([Data(x=sub_x, edge_index=sub_edge_index)])
        relabel = {int(old): new for new, old in enumerate(subset)}
        self.graph = nx.relabel_nodes(self.ori_graph.subgraph(subset.tolist()), relabel)
        self.node_idx = torch.where(subset == self.ori_node_idx)[0]
        self.num_nodes = self.graph.number_of_nodes()

        self.root_coalition = list(range(self.num_nodes))
        self.MCTSNodeClass = partial(MCTSNode, data=self.data, ori_graph=self.graph,
                                     c_puct=self.c_puct)
        self.root = self.MCTSNodeClass(self.root_coalition)
        self.state_map = {str(sorted(self.root.coalition)): self.root}
Exemplo n.º 17
0
def load_pyg_batch_from_network_list(network_list):
    """Convert every network with ``load_pyg_data_from_network`` and batch them.

    The returned ``Batch`` keeps the order of ``network_list``.
    """
    return Batch.from_data_list(
        [load_pyg_data_from_network(net) for net in network_list])
Exemplo n.º 18
0
    def plot_reconstructions(self, index=0, path=None, name=None):
        """Plot reconstruction bar chart with validation data.

        The top panel shows the fingerprint label of validation sample
        ``index``. The bottom panel shows the decoder's reconstruction,
        thresholded with ``round(sigmoid(.))``. The figure is saved to disk
        and then closed.

        Args:
            index (int): optional. The index of the validation data to use. Default is
                0.
            path (str): optional. Path to save the plottings. Default is the current
                working directory.
            name (str): optional. Name of the saved plotting. Default is
                "reconstructions.png".
        """
        root = self._rooting(path)
        self._setup_models("eval")
        filep = os.path.join(root, "reconstructions.png" if name is None else name)

        data = self.dataloader.val_loader.dataset[index]
        batch = Batch.from_data_list([data]).to(self.device)
        label = batch.y[0].to("cpu").detach()
        with torch.no_grad():
            encoder_out = self.encoder(batch)
            # Some encoders return extra outputs; only the first is decoded.
            if isinstance(encoder_out, tuple):
                encoder_out, *_ = encoder_out
            out = self.decoder(encoder_out)
        # Binarise the decoder logits for display.
        out = torch.round(torch.sigmoid(out))[0].to("cpu").detach().numpy()

        fig, axes = plt.subplots(2, 1, figsize=(8.0, 12.0))
        ax1, ax2 = axes.flatten()
        ax1.bar(list(range(out.shape[0])), label)
        ax1.set_xlabel("PubChem Fingerprint")
        ax2.bar(list(range(out.shape[0])), out)
        ax2.set_xlabel("Reconstructed Fingerprint")
        fig.savefig(filep, dpi=300, bbox_inches="tight")
        # Close this specific figure; a bare ``plt.close()`` closes whatever
        # pyplot currently considers active, which need not be ``fig``.
        plt.close(fig)
Exemplo n.º 19
0
def get_pos_neg_pairs(data_list):
    """Sample one same-label and one different-label partner for each graph.

    Labels are read from each element's ``y``. For element ``i``, the
    positive is drawn uniformly from all elements sharing its label, which
    may be element ``i`` itself. The negative is drawn uniformly from all
    elements with a different label.

    Returns ``(positive_batch, negative_batch)``, both collated with
    ``Batch.from_data_list`` and aligned with ``data_list``.
    """
    labels = torch.tensor([item.y for item in data_list], dtype=torch.long)

    def draw(mask):
        # Pick one index uniformly among the positions where ``mask`` holds.
        candidates = mask.nonzero()
        choice = torch.randint(0, candidates.shape[0], (1,))
        return data_list[candidates[choice, 0]]

    positives, negatives = [], []
    for anchor in labels:
        positives.append(draw(labels == anchor))
        negatives.append(draw(labels != anchor))

    return Batch.from_data_list(positives), Batch.from_data_list(negatives)
Exemplo n.º 20
0
def test_batch():
    """Round-trip two small graphs through ``Batch.from_data_list``.

    Checks the collated attributes (``x``, ``edge_index``, string ``s``,
    ``batch``, ``ptr``) and that ``to_data_list`` restores both graphs. Runs
    with PyG debug mode enabled and always disables it afterwards, so the
    global flag does not leak into other tests.
    """
    torch_geometric.set_debug(True)
    try:
        x1 = torch.tensor([1, 2, 3], dtype=torch.float)
        e1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
        s1 = '1'
        x2 = torch.tensor([1, 2], dtype=torch.float)
        e2 = torch.tensor([[0, 1], [1, 0]])
        s2 = '2'

        data = Batch.from_data_list([Data(x1, e1, s=s1), Data(x2, e2, s=s2)])

        assert data.__repr__() == (
            'Batch(batch=[5], edge_index=[2, 6], ptr=[3], s=[2], x=[5])')
        assert len(data) == 5
        assert data.x.tolist() == [1, 2, 3, 1, 2]
        assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
        assert data.s == ['1', '2']
        assert data.batch.tolist() == [0, 0, 0, 1, 1]
        assert data.ptr.tolist() == [0, 3, 5]
        assert data.num_graphs == 2

        data_list = data.to_data_list()
        assert len(data_list) == 2
        assert len(data_list[0]) == 3
        assert data_list[0].x.tolist() == [1, 2, 3]
        assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
        assert data_list[0].s == '1'
        assert len(data_list[1]) == 3
        assert data_list[1].x.tolist() == [1, 2]
        assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]]
        assert data_list[1].s == '2'
    finally:
        # Reset the global debug flag even if an assertion above fails.
        torch_geometric.set_debug(False)
Exemplo n.º 21
0
def _save_graphs(sharded, shard_num, out_dir):
    """Build and save two-graph examples for every ensemble in one shard.

    For each ensemble in shard ``shard_num``:

    * its rows in the shard's ``'neighbors'`` table are the positives;
    * negatives come from ``nb.get_negatives`` and get ``label = 0``;
    * ``create_labels(..., num_pos=10, neg_pos_ratio=1)`` selects the pairs;
    * each selected pair becomes two graphs via ``df_to_graph``, which are
      saved as one ``Batch`` to ``out_dir/data_{shard_num}_{k}.pt``.

    Pairs where either graph is ``None`` are skipped and do not consume an
    output index ``k``.
    """
    print(f'Processing shard {shard_num:}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    # Group by the column name rather than a one-element list: since pandas
    # 2.0 a length-1 list key yields 1-tuples as group names, which would make
    # the ``ensemble0 == ensemble_name`` filter below match nothing.
    for ensemble_name, target_df in shard.groupby('ensemble'):

        sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)

        for _, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if (graph1 is None) or (graph2 is None):
                continue

            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1
Exemplo n.º 22
0
    def _obs(self) -> Tuple[Batch, List[List[int]]]:
        """
        returns
        -------
        Tuple[Batch, List[List[int]]
            The Batch object contains the Pytorch Geometric graph of the
            heavy-atom molecule (hydrogens removed). It is centred and
            rotation-normalised, with the last three node-feature columns set
            to the resulting positions. The second element is
            ``self.nonring``: a list of torsions, each a list of four integer
            atom indices making up the torsion.
        """
        heavy = Chem.rdmolops.RemoveHs(self.mol)
        conformer = heavy.GetConformer()
        atom_seq = heavy.GetAtoms()
        bond_seq = heavy.GetBonds()

        features = []
        for atom in atom_seq:
            features.append(molecule_features.atom_type_CO(atom)
                            + molecule_features.atom_coords(atom, conformer))
        pairs = molecule_features.get_bond_pairs(heavy)
        # Bond attributes are listed twice, presumably because get_bond_pairs
        # emits each bond in both directions — confirm.
        bond_attrs = [molecule_features.bond_type(bond) for bond in bond_seq] * 2

        graph = Data(
            x=torch.tensor(features, dtype=torch.float),
            edge_index=torch.tensor(pairs, dtype=torch.long),
            edge_attr=torch.tensor(bond_attrs, dtype=torch.float),
            pos=torch.Tensor(conformer.GetPositions()),
        )
        graph = Center()(graph)
        graph = NormalizeRotation()(graph)
        # Keep the coordinate features in sync with the transformed positions.
        graph.x[:, -3:] = graph.pos
        return Batch.from_data_list([graph]), self.nonring
Exemplo n.º 23
0
    def construct_hidden_graph(self, bsize: int, num_agent: int,
                               hidden_size: int):
        """Build a ``Batch`` of ``bsize`` identical, fully connected agent graphs.

        Each graph has ``num_agent`` nodes connected by every ordered pair of
        distinct nodes: ``num_agent * (num_agent - 1)`` directed edges, no
        self-loops. Each graph has zero node features ``x`` of shape
        ``(num_agent, hidden_size)``, zero edge features ``edge_attr`` of shape
        ``(E, hidden_size)``, and a zero graph vector ``u`` of shape
        ``(1, hidden_size)``. The batch is placed on ``self.device``.
        """
        # Compute edge connections. ``reshape(-1, 2)`` keeps the tensor 2-D
        # even when there are no pairs (``num_agent < 2``). Without it,
        # ``edge_index`` would be 1-D and ``shape[1]`` below would fail.
        edge_index = torch.tensor(list(permutations(range(num_agent), 2)),
                                  dtype=torch.long).reshape(-1, 2)
        edge_index = edge_index.t().contiguous()  # Shape: [2 x E], E = n * (n - 1)
        e = edge_index.shape[1]

        # U vector. |U|-dimensional 0-vector
        x = torch.zeros((num_agent, hidden_size),
                        dtype=torch.float32,
                        device=self.device)
        u = torch.zeros((1, hidden_size),
                        dtype=torch.float32,
                        device=self.device)
        edge_attr = torch.zeros((e, hidden_size),
                                dtype=torch.float32,
                                device=self.device)

        # Create list of Data objects, then call Batch.from_data_list()
        data_objs = [
            Data(x=x.clone(),
                 edge_index=edge_index,
                 edge_attr=edge_attr.clone(),
                 u=u.clone()) for _ in range(bsize)
        ]
        batch = Batch.from_data_list(data_objs).to(x.device)
        return batch
Exemplo n.º 24
0
    def _prepare_batch(self, batch):
        """Create batch data for GAT.

        Parameters
        ----------
        batch: Tuple
          The tuple is `(inputs, labels, weights)`. ``inputs[0]`` must hold
          graph objects providing ``to_pyg_graph()``.

        Returns
        -------
        inputs: torch_geometric.data.Batch
          A mini-batch graph data for PyTorch Geometric models, on ``self.device``.
        labels: List[torch.Tensor] or None
          The labels converted to torch.Tensor.
        weights: List[torch.Tensor] or None
          The weights for each sample or sample/task pair converted to torch.Tensor.

        Raises
        ------
        ValueError
          If PyTorch Geometric cannot be imported.
        """
        try:
            from torch_geometric.data import Batch
        except ImportError as err:
            # Only a missing/broken install should be reported this way; any
            # other error keeps its own type and traceback.
            raise ValueError(
                "This class requires PyTorch Geometric to be installed.") from err

        inputs, labels, weights = batch
        pyg_graphs = [graph.to_pyg_graph() for graph in inputs[0]]
        inputs = Batch.from_data_list(pyg_graphs)
        inputs = inputs.to(self.device)
        # The parent class converts labels/weights; the inputs are handled here.
        _, labels, weights = super(GATModel, self)._prepare_batch(
            ([], labels, weights))
        return inputs, labels, weights
Exemplo n.º 25
0
    def _get_schema_graph_encoding(
        self, worlds: List[SpiderWorld], initial_graph_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """Run ``self._gnn`` over the schema graph of every world in the batch.

        For world ``i``, node features are ``initial_graph_embeddings[i]``.
        The adjacency lists from ``self._get_graph_adj_lists`` are attached as
        ``edge_index_0``, ``edge_index_1``, ... on its ``Data`` object. All
        graphs are collated with ``Batch.from_data_list`` so the GNN runs once,
        using the first ``self._gnn.num_edge_types`` edge sets.

        Returns:
            A single tensor: the GNN node outputs reshaped to
            ``(batch_size, max_num_entities, -1)``, where ``max_num_entities``
            is the largest entity count over ``worlds``.
        """
        max_num_entities = max(
            len(world.db_context.knowledge_graph.entities) for world in worlds
        )
        batch_size = initial_graph_embeddings.size(0)

        graph_data_list = []
        for batch_index, world in enumerate(worlds):
            x = initial_graph_embeddings[batch_index]

            adj_list = self._get_graph_adj_lists(
                initial_graph_embeddings.device, world,
                initial_graph_embeddings.size(1) - 1)
            graph_data = Data(x)
            for i, l in enumerate(adj_list):
                graph_data[f'edge_index_{i}'] = l
            graph_data_list.append(graph_data)

        batch = Batch.from_data_list(graph_data_list)

        gnn_output = self._gnn(batch.x, [
            batch[f'edge_index_{i}'] for i in range(self._gnn.num_edge_types)
        ])

        # NOTE(review): the view assumes every graph contributes exactly
        # ``max_num_entities`` nodes, i.e. ``initial_graph_embeddings`` is
        # (batch_size, max_num_entities, dim) — confirm.
        entities_encodings = gnn_output.view(batch_size, max_num_entities, -1)
        return entities_encodings
Exemplo n.º 26
0
    def forward(self, inputs, training=None):
        """Predict with an edge mask derived from node-pair embeddings.

        ``inputs`` is ``(x, embed, edge_index, tmp)``. Every ordered node pair
        ``(u, v)`` is scored by passing ``[embed[u], embed[v]]`` through
        ``self.elayers``. The scores are sampled with
        ``self.concrete_sample(..., beta=tmp, training=training)`` and stored as
        an ``N x N`` matrix in ``self.mask_sigmoid``. That matrix is averaged
        with its transpose, gathered at ``edge_index``, passed through
        ``inv_sigmoid`` and installed with ``__set_masks__`` before
        ``self.model`` runs on a one-graph ``Batch``.

        Returns ``(outputs[1].squeeze(), edge_mask)``.
        """
        x, embed, edge_index, tmp = inputs
        num_nodes, feat_dim = embed.shape[0], embed.shape[1]

        # Row u * num_nodes + v holds embed[u] followed by embed[v].
        left = embed.unsqueeze(1).repeat(1, num_nodes, 1).reshape(-1, feat_dim)
        right = embed.unsqueeze(0).repeat(num_nodes, 1, 1).reshape(-1, feat_dim)
        h = torch.cat([left, right], dim=-1).to(self.device)
        for elayer in self.elayers:
            h = elayer(h)
        sampled = self.concrete_sample(h.reshape(-1), beta=tmp, training=training)
        self.mask_sigmoid = sampled.reshape(num_nodes, num_nodes)

        # Give (u, v) and (v, u) the same weight.
        sym_mask = (self.mask_sigmoid + self.mask_sigmoid.transpose(0, 1)) / 2
        # Undo the sigmoid, presumably because the MessagePassing layers apply
        # one to the installed mask again.
        edge_mask = inv_sigmoid(sym_mask[edge_index[0], edge_index[1]])
        self.__clear_masks__()
        self.__set_masks__(x, edge_index, edge_mask)

        # Model prediction with the edge mask installed.
        graph = Batch.from_data_list([Data(x=x, edge_index=edge_index)])
        graph.to(self.device)
        outputs = self.model(graph)
        return outputs[1].squeeze(), edge_mask
Exemplo n.º 27
0
 def stack(inp):
     """Recursively collate a sequence of identically structured samples.

     The exact type of ``inp[0]`` (no subclass matching) decides the result:

     * list -> list; position ``k`` holds the collation of every sample's
       ``k``-th element;
     * dict -> dict; items are paired by position, their keys must match
       (asserted), and each key maps to the collation of its values;
     * ``torch.Tensor`` -> ``pad_tensor(inp)`` stacked along a new dim 0;
     * ``np.ndarray`` -> converted with ``torch.from_numpy``, then as tensors;
     * str -> ``inp`` returned unchanged;
     * ``Data`` -> ``Batch.from_data_list(inp)``.

     Raises:
         ValueError: for any other element type.
     """
     kind = type(inp[0])
     if kind == list:
         return [stack(column) for column in zip(*inp)]
     if kind == dict:
         merged = {}
         for pairs in zip(*[sample.items() for sample in inp]):
             keys, values = zip(*pairs)
             for key in keys:
                 assert key == keys[0], "Key value mismatch."
             merged[key] = stack(values)
         return merged
     if kind == torch.Tensor:
         return torch.stack(pad_tensor(inp), 0)
     if kind == np.ndarray:
         return torch.stack(pad_tensor([torch.from_numpy(arr) for arr in inp]), 0)
     if kind == str:
         return inp
     if kind == Data:  # Graph from torch.geometric, create a batch
         return Batch.from_data_list(inp)
     raise ValueError("Cannot handle type {}".format(kind))
Exemplo n.º 28
0
def train_parallel(model, optimizer, loader, total, batch_size, loss_ftn_obj):
    """Train ``model`` for one pass over ``loader``; return the mean batch loss.

    ``loss_ftn_obj.name`` selects how each batch loss is computed:

    * ``'vae_loss'``: the model returns ``(output, mu, log_var)`` and the
      target is the concatenated ``x`` of the graphs in ``data``.
    * ``'emd_loss'`` / ``'chamfer_loss'``: ``data`` is collated into a
      ``Batch``, and the loss receives its ``x`` and ``batch`` vectors; the
      result is averaged.
    * ``'emd_loss_layer'``: the model returns the loss itself (averaged).
    * anything else: the loss between the output and the concatenated ``x``.

    ``total / batch_size`` only sets the progress-bar length.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()

    sum_loss = 0.
    num_batches = 0
    t = tqdm.tqdm(enumerate(loader), total=total / batch_size)
    for _, data in t:
        optimizer.zero_grad()

        if loss_ftn_obj.name == 'vae_loss':
            batch_output, mu, log_var = model(data)
            y = torch.cat([d.x for d in data]).to(device)
            batch_loss = loss_ftn_obj.loss_ftn(batch_output, y, mu, log_var)
        elif loss_ftn_obj.name in ('emd_loss', 'chamfer_loss'):
            batch_output = model(data)
            data_batch = Batch.from_data_list(data).to(device)
            batch_loss = loss_ftn_obj.loss_ftn(batch_output, data_batch.x,
                                               data_batch.batch)
            batch_loss = batch_loss.mean()
        elif loss_ftn_obj.name == 'emd_loss_layer':
            _, batch_loss = model(data)
            batch_loss = batch_loss.mean()
        else:
            batch_output = model(data)
            y = torch.cat([d.x for d in data]).to(device)
            batch_loss = loss_ftn_obj.loss_ftn(batch_output, y)

        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description('train loss = %.5f' % batch_loss_item)
        t.refresh()  # to show immediately the update
        sum_loss += batch_loss_item
        num_batches += 1
        optimizer.step()

    # Counting batches explicitly avoids reading an unbound loop variable
    # when the loader is empty.
    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return sum_loss / num_batches
Exemplo n.º 29
0
 def collate(output_list):
     """Recursively merge a list of per-chunk outputs.

     Tensors are concatenated along dim 0. PyG ``Data`` objects are merged
     with ``Batch.from_data_list``, but only when the ``geometric`` flag
     (defined outside this function) is truthy. Anything else is treated as
     a sequence of parallel fields, and a list of their recursive collations
     is returned.
     """
     first = output_list[0]
     if isinstance(first, torch.Tensor):
         return torch.cat(output_list, dim=0)
     if geometric and isinstance(first, Data):
         return Batch.from_data_list(output_list)
     return [collate(field) for field in zip(*output_list)]
Exemplo n.º 30
0
 def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
     """Data loader that collates lists of ``Data`` objects into a ``Batch``.

     ``dataset``, ``batch_size``, ``shuffle`` and ``kwargs`` are forwarded to
     the parent loader. ``collate_fn`` is the bound
     ``Batch.from_data_list`` rather than a lambda. A lambda cannot be
     pickled, which breaks multi-process loading (``num_workers > 0``) under
     the ``spawn`` start method.
     """
     super(DataLoader, self).__init__(
         dataset,
         batch_size,
         shuffle,
         collate_fn=Batch.from_data_list,
         **kwargs)