示例#1
0
 def __call__(self, data):
     """Run the Malkin 2-set / connected 3-set transforms on a feature subset.

     While the transforms run, only the first five columns of ``data.x`` are
     exposed (``data.x[:, :5]``). The complete feature matrix is put back on
     the object returned by the last transform.

     Assumes ``data.x`` is set and at least 2-D, since it is sliced with
     ``[:, :5]``.
     """
     full_features = data.x
     data.x = full_features[:, :5]
     # Instantiate each transform right before applying it, in order.
     for make_transform in (TwoMalkin, ConnectedThreeMalkin):
         data = make_transform()(data)
     data.x = full_features
     return data
示例#2
0
 def __call__(self, data):
     """Attach Malkin structures, then replace features with bucketed degrees.

     Any existing ``data.x`` is discarded. Each node's degree is counted over
     ``edge_index[0]`` and integer-divided by ``self.div``. The result is
     one-hot encoded with ``degrees[self.args.dataset] // self.div + 1``
     classes and cast to float.
     """
     # Single-column placeholder features while the Malkin transforms run;
     # presumably they expect data.x to be present — confirm.
     data.x = torch.zeros((data.num_nodes, 1), dtype=torch.float)
     for make_transform in (TwoMalkin, ConnectedThreeMalkin):
         data = make_transform()(data)
     data.x = degree(data.edge_index[0], data.num_nodes, dtype=torch.long)
     bucket_ids = data.x // self.div
     # NOTE(review): `degrees` is a module-level lookup defined elsewhere;
     # presumably it maps dataset name -> maximum node degree — confirm.
     num_buckets = degrees[self.args.dataset] // self.div + 1
     data.x = F.one_hot(bucket_ids, num_classes=num_buckets).to(torch.float)
     return data
示例#3
0
 def __call__(self, data):
     """Attach Malkin structures, then use one-hot node degrees as features.

     Any existing ``data.x`` is discarded. Degrees are counted over
     ``edge_index[0]`` and encoded with a fixed 136 classes as float.
     """
     # Placeholder features while the Malkin transforms run (presumably they
     # require data.x to exist — confirm).
     data.x = torch.zeros((data.num_nodes, 1), dtype=torch.float)
     for make_transform in (TwoMalkin, ConnectedThreeMalkin):
         data = make_transform()(data)
     source_nodes = data.edge_index[0]
     data.x = degree(source_nodes, data.num_nodes, dtype=torch.long)
     # NOTE(review): 136 is hard-coded; presumably max degree + 1 for the
     # target dataset — confirm, since larger degrees cannot be encoded.
     data.x = one_hot(data.x, 136, torch.float)
     return data
示例#4
0
 def __call__(self, data):
     """Attach Malkin 2-set / connected 3-set structure and ensure features.

     Graphs that already carry ``data.x`` keep those exact features. They are
     restored after the transforms run.

     Featureless graphs (``data.x is None``) instead get a float one-hot
     encoding of each node's degree. Degrees are counted over
     ``edge_index[0]`` after the transforms, using
     ``int(degrees[self.args.dataset])`` classes.
     """
     given_x = data.x
     needs_features = given_x is None
     if needs_features:
         # Placeholder features, presumably so the Malkin transforms always
         # see a data.x — confirm.
         data.x = torch.zeros((data.num_nodes, 1), dtype=torch.float)

     data = TwoMalkin()(data)
     data = ConnectedThreeMalkin()(data)

     if not needs_features:
         # Put the caller-supplied features back untouched.
         data.x = given_x
         return data

     data.x = degree(data.edge_index[0], data.num_nodes, dtype=torch.long)
     # NOTE(review): F.one_hot needs every value < num_classes; if
     # `degrees[...]` stores the maximum degree this is off by one — confirm.
     data.x = F.one_hot(
         data.x, num_classes=int(degrees[self.args.dataset])).to(torch.float)
     return data
# Parse command-line options; `parser` is built outside the visible snippet.
args = parser.parse_args()


class MyFilter:
    """Dataset pre-filter deciding which graphs are kept.

    A graph is rejected when it has 450 or more nodes. It is also rejected
    when it has exactly 7 nodes and 12 edges. The reason for excluding that
    particular shape is not visible here.
    """

    def __call__(self, data):
        # Oversized graphs are dropped outright.
        if data.num_nodes >= 450:
            return False
        # Keep everything except the 7-node / 12-edge case.
        return data.num_nodes != 7 or data.num_edges != 12


BATCH = 20


def _dense_one_hot(labels):
    """Relabel *labels* to consecutive ids and one-hot encode them.

    ``torch.unique(labels, True, True)`` (sorted=True, return_inverse=True)
    yields ``(values, inverse)``. The inverse maps every entry of *labels* to
    the position of its value among the sorted distinct values, i.e. to an
    id in ``0 .. k-1``.

    Returns:
        ``(encoded, k)``: the float one-hot matrix and the number of distinct
        ids ``k`` as a Python int.
    """
    _, compact_ids = torch.unique(labels, True, True)
    n_ids = compact_ids.max().item() + 1
    encoded = F.one_hot(compact_ids, num_classes=n_ids).to(torch.float)
    return encoded, n_ids


# Dataset root: <script dir>/../data/1-2-3-PROTEINS.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                '1-2-3-PROTEINS')
# The Malkin 2-set and connected 3-set transforms run as the pre-transform;
# MyFilter decides which graphs are kept.
dataset = TUDataset(
    path,
    name='PROTEINS',
    pre_transform=T.Compose([TwoMalkin(), ConnectedThreeMalkin()]),
    pre_filter=MyFilter(),
)

# Shuffle the graph order with a random permutation.
perm = torch.randperm(len(dataset), dtype=torch.long)
dataset = dataset[perm]

# Replace raw isomorphism-type labels with dense one-hot features. The class
# counts are kept (presumably to size model inputs later).
dataset.data.iso_type_2, num_i_2 = _dense_one_hot(dataset.data.iso_type_2)
dataset.data.iso_type_3, num_i_3 = _dense_one_hot(dataset.data.iso_type_3)
示例#6
0

class MyPreTransform:
    """Identity pre-transform: hands each graph back unchanged."""

    def __call__(self, data):
        unchanged = data
        return unchanged


# Mini-batch size (presumably consumed by data loaders defined later).
BATCH = 32
# Dataset root: <script dir>/../data/1-2-3-MUTAG.
path = osp.join(
    osp.dirname(osp.realpath(__file__)), '..', 'data', '1-2-3-MUTAG')
# MUTAG graphs: MyPreTransform, then the 2-set Malkin transform and the
# connected 3-set *Local* transform form the pre-transform; MyFilter decides
# which graphs are kept.
dataset = TUDataset(
    path,
    name='MUTAG',
    pre_transform=T.Compose(
        [MyPreTransform(),
         TwoMalkin(), ConnectedThreeLocal()]),
    pre_filter=MyFilter())

# Random permutation of the graph order.
# NOTE(review): the permutation is saved to a path relative to the current
# working directory and immediately loaded back, so as written the save/load
# pair is a round-trip no-op. Presumably it exists to persist the shuffle for
# reproducible reruns (e.g. by skipping randperm/save and only loading) —
# confirm intent.
perm = torch.randperm(len(dataset), dtype=torch.long)
torch.save(perm, 'mutag_perm.pt')
perm = torch.load('mutag_perm.pt')
dataset = dataset[perm]

# Remap raw 2-set isomorphism types to consecutive ids. The positional args
# are sorted=True, return_inverse=True, and [1] selects the inverse indices.
# Then one-hot encode them as float features.
dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1]
num_i_2 = dataset.data.iso_type_2.max().item() + 1
dataset.data.iso_type_2 = F.one_hot(
    dataset.data.iso_type_2, num_classes=num_i_2).to(torch.float)

# Same remapping for the 3-set isomorphism types; num_i_3 is the number of
# distinct 3-set types.
dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1]
num_i_3 = dataset.data.iso_type_3.max().item() + 1
dataset.data.iso_type_3 = F.one_hot(
示例#7
0

class MyPreTransform:
    """Pre-transform that keeps only the trailing three feature columns.

    The original author notes these columns are the node attributes. The
    graph object is modified in place and returned.
    """

    def __call__(self, data):
        trailing = slice(-3, None)
        data.x = data.x[:, trailing]
        return data


# Mini-batch size (presumably consumed by data loaders defined later).
BATCH = 20
# Dataset root: <script dir>/../data/1-2-3-PROTEINS.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                '1-2-3-PROTEINS')
# PROTEINS graphs: MyPreTransform (keeps the trailing feature columns), then
# the 2-set and connected 3-set Malkin transforms form the pre-transform;
# MyFilter decides which graphs are kept.
dataset = TUDataset(path,
                    name='PROTEINS',
                    pre_transform=T.Compose([
                        MyPreTransform(),
                        TwoMalkin(),
                        ConnectedThreeMalkin()
                    ]),
                    pre_filter=MyFilter())

# Shuffle the graph order with a random permutation (no seed is set in this
# block).
perm = torch.randperm(len(dataset), dtype=torch.long)
dataset = dataset[perm]

# Remap raw 2-set isomorphism types to consecutive ids. The positional args
# are sorted=True, return_inverse=True, and [1] selects the inverse indices.
# Then one-hot encode them as float features.
dataset.data.iso_type_2 = torch.unique(dataset.data.iso_type_2, True, True)[1]
num_i_2 = dataset.data.iso_type_2.max().item() + 1
dataset.data.iso_type_2 = F.one_hot(dataset.data.iso_type_2,
                                    num_classes=num_i_2).to(torch.float)

# Same remapping for the 3-set isomorphism types; num_i_3 is the number of
# distinct 3-set types.
dataset.data.iso_type_3 = torch.unique(dataset.data.iso_type_3, True, True)[1]
num_i_3 = dataset.data.iso_type_3.max().item() + 1
dataset.data.iso_type_3 = F.one_hot(dataset.data.iso_type_3,