Exemplo n.º 1
0
    def process(self):
        """Parse the raw files, drop invalid molecules and save the dataset.

        Reads three raw files:

        * ``self.raw_paths[0]``: records separated by ``$$$$``, each parsed
          with ``parse_sdf`` into one graph.
        * ``self.raw_paths[1]``: comma-separated target table; after the
          header row, fields 4-15 of row ``i`` become ``data.y`` (a single-row
          tensor) of the ``i``-th graph.
        * ``self.raw_paths[2]``: list of molecules to discard.

        The remaining graphs go through ``self.pre_filter`` and
        ``self.pre_transform`` when those are set, are collated with
        ``self.collate`` and saved to ``self.processed_paths[0]``.
        """
        with open(self.raw_paths[0], 'r') as f:
            # The text after the final '$$$$' separator is not a record.
            data_list = [parse_sdf(x) for x in f.read().split('$$$$\n')[:-1]]

        with open(self.raw_paths[1], 'r') as f:
            # Drop the header row and the empty string after the last newline.
            y = parse_txt_array(f.read().split('\n')[1:-1], ',', 4, 16)

        for i, data in enumerate(data_list):
            data.y = y[i].view(1, -1)

        # Remove invalid data. Rows [9:-2] skip the file's header and footer
        # lines. The first column of each row is a 1-based molecule index
        # (molecule 1 is the first SDF record), so subtract 1 to obtain a
        # position in `data_list`; using it unshifted would discard the
        # neighbouring molecule instead of the invalid one.
        with open(self.raw_paths[2], 'r') as f:
            skip = {
                int(line.split()[0]) - 1
                for line in f.read().split('\n')[9:-2]
            }
        data_list = [
            data for i, data in enumerate(data_list) if i not in skip
        ]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
def parse_off(src):
    """Parse the lines of an OFF mesh file into a ``Data`` object.

    Args:
        src: the file's lines, starting with the ``OFF`` header line. The
            list is read only; it is not modified.

    Returns:
        ``Data`` with ``pos`` parsed from the ``num_nodes`` vertex lines and
        ``face`` parsed from the ``num_faces`` face lines (leading per-face
        vertex count skipped), transposed so faces run along the second
        dimension. Presumably every face must have the same number of
        vertices (e.g. an all-triangle mesh) for the face lines to parse
        into one tensor.
    """
    header = src[0]
    if header.strip() == 'OFF':
        # Regular layout: the counts line follows the "OFF" line. Comparing
        # the stripped header also accepts "OFF\r" (CRLF files) and
        # trailing whitespace.
        counts_line, body = src[1], src[2:]
    else:
        # Some files lack a line break after "OFF", so the counts follow
        # the keyword on the same line.
        counts_line, body = header[3:], src[1:]

    num_nodes, num_faces = [int(item) for item in counts_line.split()[:2]]

    pos = parse_txt_array(body[:num_nodes])

    face = body[num_nodes:num_nodes + num_faces]
    # Column 0 of each face line is its vertex count; keep only the indices.
    face = parse_txt_array(face, start=1, dtype=torch.long).t().contiguous()

    data = Data(pos=pos)
    data.face = face

    return data
Exemplo n.º 3
0
def parse_sdf(src):
    """Build a molecular graph from the text of one SDF record.

    The record's first three lines are skipped; the next line supplies the
    atom and bond counts and is followed by the atom block and then the
    bond block.

    Args:
        src (str): a single record, e.g. one piece of an SDF file split on
            its ``$$$$`` separator.

    Returns:
        Data: with
            * ``x``: one-hot encoding of each atom's fourth field (the
              element symbol), mapped through ``elems``;
            * ``pos``: the first three fields of each atom line;
            * ``edge_index``: the first two fields of each bond line minus
              one, with every bond stored in both directions;
            * ``edge_attr``: the third field of each bond line minus one,
              duplicated to line up with ``edge_index``;
        ``edge_index`` and ``edge_attr`` are passed through ``coalesce``.
    """
    lines = src.split('\n')[3:]
    num_atoms, num_bonds = map(int, lines[0].split()[:2])

    atom_lines = lines[1:num_atoms + 1]
    pos = parse_txt_array(atom_lines, end=3)
    atom_types = torch.tensor([elems[line.split()[3]] for line in atom_lines])
    x = one_hot(atom_types, len(elems))

    bond_start = num_atoms + 1
    bond_lines = lines[bond_start:bond_start + num_bonds]
    # Atom references in the bond block are 1-based; shift them to 0-based.
    source, target = parse_txt_array(bond_lines, end=2, dtype=torch.long).t() - 1
    # Add each bond once per direction.
    forward = torch.cat([source, target], dim=0)
    backward = torch.cat([target, source], dim=0)
    edge_index = torch.stack([forward, backward], dim=0)

    bond_type = parse_txt_array(bond_lines, start=2, end=3) - 1
    edge_attr = torch.cat([bond_type, bond_type], dim=0)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)